diff --git a/eng/Packages.Data.props b/eng/Packages.Data.props index be8ac99162984..4420a80268b03 100644 --- a/eng/Packages.Data.props +++ b/eng/Packages.Data.props @@ -175,7 +175,7 @@ - + @@ -202,10 +202,10 @@ - + - + @@ -232,7 +232,7 @@ - + diff --git a/sdk/core/Azure.Core.TestFramework/src/Azure.Core.TestFramework.csproj b/sdk/core/Azure.Core.TestFramework/src/Azure.Core.TestFramework.csproj index 44530c87fad0e..68aed12efc9f2 100644 --- a/sdk/core/Azure.Core.TestFramework/src/Azure.Core.TestFramework.csproj +++ b/sdk/core/Azure.Core.TestFramework/src/Azure.Core.TestFramework.csproj @@ -6,6 +6,7 @@ + diff --git a/sdk/core/Azure.Core/tests/TestServer.cs b/sdk/core/Azure.Core.TestFramework/src/TestServer.cs similarity index 95% rename from sdk/core/Azure.Core/tests/TestServer.cs rename to sdk/core/Azure.Core.TestFramework/src/TestServer.cs index e884750e55cb9..2c87078740c77 100644 --- a/sdk/core/Azure.Core/tests/TestServer.cs +++ b/sdk/core/Azure.Core.TestFramework/src/TestServer.cs @@ -12,7 +12,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; -namespace Azure.Core.Tests +namespace Azure.Core.TestFramework { public class TestServer : IStartup, IDisposable { @@ -21,7 +21,7 @@ public class TestServer : IStartup, IDisposable public Uri Address => new Uri(_host.ServerFeatures.Get().Addresses.First()); - public TestServer(Action app) : this(context => { app(context); return Task.CompletedTask;}) + public TestServer(Action app) : this(context => { app(context); return Task.CompletedTask; }) { } diff --git a/sdk/core/Azure.Core/tests/HttpWebRequestTransportFunctionalTest.cs b/sdk/core/Azure.Core/tests/HttpWebRequestTransportFunctionalTest.cs index 0fa82ef39a343..201e60452c20c 100644 --- a/sdk/core/Azure.Core/tests/HttpWebRequestTransportFunctionalTest.cs +++ b/sdk/core/Azure.Core/tests/HttpWebRequestTransportFunctionalTest.cs @@ -5,6 +5,7 @@ using System.Net.Http; using System.Threading.Tasks; using Azure.Core.Pipeline; +using Azure.Core.TestFramework; using NUnit.Framework; namespace Azure.Core.Tests diff --git a/sdk/identity/Azure.Identity/src/DefaultAzureCredentialFactory.cs b/sdk/identity/Azure.Identity/src/DefaultAzureCredentialFactory.cs index f8204aa562a58..cee0f2d843c23 100644 --- a/sdk/identity/Azure.Identity/src/DefaultAzureCredentialFactory.cs +++ b/sdk/identity/Azure.Identity/src/DefaultAzureCredentialFactory.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using Azure.Core; namespace Azure.Identity @@ -25,7 +26,14 @@ public virtual TokenCredential CreateEnvironmentCredential() public virtual TokenCredential CreateManagedIdentityCredential(string clientId) { - return new ManagedIdentityCredential(clientId, Pipeline); + return new ManagedIdentityCredential(new ManagedIdentityClient( + new ManagedIdentityClientOptions + { + ClientId = clientId, + Pipeline = Pipeline, + InitialImdsConnectionTimeout = TimeSpan.FromSeconds(1) + }) + ); } public virtual TokenCredential CreateSharedTokenCacheCredential(string tenantId, string username) diff --git a/sdk/identity/Azure.Identity/src/ImdsManagedIdentitySource.cs b/sdk/identity/Azure.Identity/src/ImdsManagedIdentitySource.cs index 538da3d322e03..31b5e1d7a0798 100644 --- a/sdk/identity/Azure.Identity/src/ImdsManagedIdentitySource.cs +++ b/sdk/identity/Azure.Identity/src/ImdsManagedIdentitySource.cs @@ -29,9 +29,12 @@ internal class ImdsManagedIdentitySource : ManagedIdentitySource private readonly string _clientId; private readonly Uri _imdsEndpoint; - internal ImdsManagedIdentitySource(CredentialPipeline pipeline, string clientId) : base(pipeline) + private TimeSpan? _imdsNetworkTimeout; + + internal ImdsManagedIdentitySource(ManagedIdentityClientOptions options) : base(options.Pipeline) { - _clientId = clientId; + _clientId = options.ClientId; + _imdsNetworkTimeout = options.InitialImdsConnectionTimeout; if (!string.IsNullOrEmpty(EnvironmentVariables.PodIdentityEndpoint)) { @@ -66,6 +69,15 @@ protected override Request CreateRequest(string[] scopes) return request; } + protected override HttpMessage CreateHttpMessage(Request request) + { + HttpMessage message = base.CreateHttpMessage(request); + + message.NetworkTimeout = _imdsNetworkTimeout; + + return message; + } + public async override ValueTask AuthenticateAsync(bool async, TokenRequestContext context, CancellationToken cancellationToken) { try @@ -88,6 +100,9 @@ public async override ValueTask AuthenticateAsync(bool async, Token protected override async ValueTask HandleResponseAsync(bool async, TokenRequestContext context, Response response, CancellationToken cancellationToken) { + // if we got a response from IMDS we can stop limiting the network timeout + _imdsNetworkTimeout = null; + // handle error status codes indicating managed identity is not available var baseMessage = response.Status switch { diff --git a/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs b/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs index b2f9f6090308d..2dd1e4ae82763 100644 --- a/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs +++ b/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs @@ -47,7 +47,7 @@ private static ManagedIdentitySource SelectManagedIdentitySource(ManagedIdentity AzureArcManagedIdentitySource.TryCreate(options) ?? ServiceFabricManagedIdentitySource.TryCreate(options) ?? TokenExchangeManagedIdentitySource.TryCreate(options) ?? - new ImdsManagedIdentitySource(options.Pipeline, options.ClientId); + new ImdsManagedIdentitySource(options); } } } diff --git a/sdk/identity/Azure.Identity/src/ManagedIdentityClientOptions.cs b/sdk/identity/Azure.Identity/src/ManagedIdentityClientOptions.cs index 9ca051cac54d1..b762e7b7771cc 100644 --- a/sdk/identity/Azure.Identity/src/ManagedIdentityClientOptions.cs +++ b/sdk/identity/Azure.Identity/src/ManagedIdentityClientOptions.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; + namespace Azure.Identity { internal class ManagedIdentityClientOptions @@ -11,6 +13,8 @@ internal class ManagedIdentityClientOptions public bool PreserveTransport { get; set; } + public TimeSpan? InitialImdsConnectionTimeout { get; set; } + public CredentialPipeline Pipeline { get; set; } } } diff --git a/sdk/identity/Azure.Identity/src/ManagedIdentitySource.cs b/sdk/identity/Azure.Identity/src/ManagedIdentitySource.cs index b7a933581d54b..050519ffcb673 100644 --- a/sdk/identity/Azure.Identity/src/ManagedIdentitySource.cs +++ b/sdk/identity/Azure.Identity/src/ManagedIdentitySource.cs @@ -27,7 +27,7 @@ protected ManagedIdentitySource(CredentialPipeline pipeline) public virtual async ValueTask AuthenticateAsync(bool async, TokenRequestContext context, CancellationToken cancellationToken) { using Request request = CreateRequest(context.Scopes); - using HttpMessage message = new HttpMessage(request, _responseClassifier); + using HttpMessage message = CreateHttpMessage(request); if (async) { await Pipeline.HttpPipeline.SendAsync(message, cancellationToken).ConfigureAwait(false); @@ -58,9 +58,14 @@ protected virtual async ValueTask HandleResponseAsync(bool async, T protected abstract Request CreateRequest(string[] scopes); + protected virtual HttpMessage CreateHttpMessage(Request request) + { + return new HttpMessage(request, _responseClassifier); + } + protected static async Task GetMessageFromResponse(Response response, bool async, CancellationToken cancellationToken) { - if (response?.ContentStream == null) + if (response?.ContentStream == null || !response.ContentStream.CanRead || response.ContentStream.Length == 0) { return null; } diff --git a/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs b/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs index 16fce76ad0cb5..86f3e4423daa6 100644 --- a/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs @@ -6,10 +6,12 @@ using System.IO; using System.Net.Http; using System.Text; +using System.Threading; using System.Threading.Tasks; using Azure.Core; using Azure.Core.TestFramework; using Azure.Identity.Tests.Mock; +using Microsoft.AspNetCore.Http; using NUnit.Framework; namespace Azure.Identity.Tests @@ -318,7 +320,7 @@ public async Task VerifyMsiUnavailableOnIMDSAggregateExcpetion() using var environment = new TestEnvVar(new() { { "MSI_ENDPOINT", null }, { "MSI_SECRET", null }, { "IDENTITY_ENDPOINT", null }, { "IDENTITY_HEADER", null }, { "AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.001/" } }); // setting the delay to 1ms and retry mode to fixed to speed up test - var options = new TokenCredentialOptions() { Retry = { Delay = TimeSpan.FromMilliseconds(1), Mode = RetryMode.Fixed } }; + var options = new TokenCredentialOptions() { Retry = { Delay = TimeSpan.FromMilliseconds(1), Mode = RetryMode.Fixed, NetworkTimeout = TimeSpan.FromMilliseconds(100) } }; var credential = InstrumentClient(new ManagedIdentityCredential(options: options)); @@ -335,7 +337,7 @@ public async Task VerifyMsiUnavailableOnIMDSRequestFailedExcpetion() { using var environment = new TestEnvVar(new() { { "MSI_ENDPOINT", null }, { "MSI_SECRET", null }, { "IDENTITY_ENDPOINT", null }, { "IDENTITY_HEADER", null }, { "AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.001/" } }); - var options = new TokenCredentialOptions() { Retry = { MaxRetries = 0 } }; + var options = new TokenCredentialOptions() { Retry = { MaxRetries = 0, NetworkTimeout = TimeSpan.FromMilliseconds(100) } }; var credential = InstrumentClient(new ManagedIdentityCredential(options: options)); @@ -346,6 +348,101 @@ public async Task VerifyMsiUnavailableOnIMDSRequestFailedExcpetion() await Task.CompletedTask; } + [NonParallelizable] + [Test] + public async Task VerifyMsiUnavailableOnIMDSGatewayErrorResponse([Values(502, 504)]int statusCode) + { + using var server = new TestServer(context => + { + context.Response.StatusCode = statusCode; + }); + + using var environment = new TestEnvVar(new() { { "MSI_ENDPOINT", null }, { "MSI_SECRET", null }, { "IDENTITY_ENDPOINT", null }, { "IDENTITY_HEADER", null }, { "AZURE_POD_IDENTITY_AUTHORITY_HOST", server.Address.AbsoluteUri } }); + + // setting the delay to 1ms and retry mode to fixed to speed up test + var options = new TokenCredentialOptions() { }; + + var credential = InstrumentClient(new ManagedIdentityCredential(options: options)); + + var ex = Assert.ThrowsAsync(async () => await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default))); + + Assert.That(ex.Message, Does.Contain(ImdsManagedIdentitySource.GatewayError)); + + await Task.CompletedTask; + } + + [NonParallelizable] + [Test] + public async Task VerifyInitialImdsConnectionTimeoutHonored() + { + using var server = new TestServer(async context => + { + await Task.Delay(8000); + + context.Response.StatusCode = 418; + }); + + using var environment = new TestEnvVar(new() { { "MSI_ENDPOINT", null }, { "MSI_SECRET", null }, { "IDENTITY_ENDPOINT", null }, { "IDENTITY_HEADER", null }, { "AZURE_POD_IDENTITY_AUTHORITY_HOST", server.Address.AbsoluteUri } }); + + // setting the delay to 1ms and retry mode to fixed to speed up test + var options = new TokenCredentialOptions() { Retry = { Delay = TimeSpan.FromMilliseconds(0), Mode = RetryMode.Fixed } }; + + var pipeline = CredentialPipeline.GetInstance(options); + + var miClientOptions = new ManagedIdentityClientOptions { InitialImdsConnectionTimeout = TimeSpan.FromMilliseconds(100), Pipeline = pipeline }; + + var credential = InstrumentClient(new ManagedIdentityCredential(new ManagedIdentityClient(miClientOptions))); + + var startTime = DateTimeOffset.UtcNow; + + var ex = Assert.ThrowsAsync(async () => await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default))); + + Assert.That(ex.Message, Does.Contain(ImdsManagedIdentitySource.AggregateError)); + + Assert.Less(DateTimeOffset.UtcNow - startTime, TimeSpan.FromSeconds(2)); + + await Task.CompletedTask; + } + + [NonParallelizable] + [Test] + public async Task VerifyInitialImdsConnectionTimeoutRelaxed() + { + string token = Guid.NewGuid().ToString(); + int callCount = 0; + + using var server = new TestServer(async context => + { + if (Interlocked.Increment(ref callCount) > 1) + { + await Task.Delay(2000); + } + + await context.Response.WriteAsync($"{{ \"access_token\": \"{token}\", \"expires_on\": \"3600\" }}"); + }); + + using var environment = new TestEnvVar(new() { { "MSI_ENDPOINT", null }, { "MSI_SECRET", null }, { "IDENTITY_ENDPOINT", null }, { "IDENTITY_HEADER", null }, { "AZURE_POD_IDENTITY_AUTHORITY_HOST", server.Address.AbsoluteUri } }); + + // setting the delay to 1ms and retry mode to fixed to speed up test + var options = new TokenCredentialOptions() { Retry = { Delay = TimeSpan.FromMilliseconds(0), Mode = RetryMode.Fixed } }; + + var pipeline = CredentialPipeline.GetInstance(options); + + var miClientOptions = new ManagedIdentityClientOptions { InitialImdsConnectionTimeout = TimeSpan.FromMilliseconds(1000), Pipeline = pipeline }; + + var credential = InstrumentClient(new ManagedIdentityCredential(new ManagedIdentityClient(miClientOptions))); + + var at = await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default)); + + Assert.AreEqual(token, at.Token); + + var at2 = await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default)); + + Assert.AreEqual(token, at.Token); + + Assert.AreEqual(2, callCount); + } + [Test] public async Task VerifyClientAuthenticateThrows() { diff --git a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Microsoft.Azure.WebJobs.Extensions.ServiceBus.Tests.csproj b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Microsoft.Azure.WebJobs.Extensions.ServiceBus.Tests.csproj index 567e33776588b..a294ab32d98a5 100644 --- a/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Microsoft.Azure.WebJobs.Extensions.ServiceBus.Tests.csproj +++ b/sdk/servicebus/Microsoft.Azure.WebJobs.Extensions.ServiceBus/tests/Microsoft.Azure.WebJobs.Extensions.ServiceBus.Tests.csproj @@ -5,8 +5,6 @@ - -