diff --git a/eng/Packages.Data.props b/eng/Packages.Data.props index 90250c5d9ca93..c69c48779f5b3 100644 --- a/eng/Packages.Data.props +++ b/eng/Packages.Data.props @@ -354,7 +354,7 @@ - + diff --git a/sdk/keyvault/Azure.Security.KeyVault.Administration/CHANGELOG.md b/sdk/keyvault/Azure.Security.KeyVault.Administration/CHANGELOG.md index 4b08ca6035e40..b917429847d20 100644 --- a/sdk/keyvault/Azure.Security.KeyVault.Administration/CHANGELOG.md +++ b/sdk/keyvault/Azure.Security.KeyVault.Administration/CHANGELOG.md @@ -6,6 +6,7 @@ - Added support for service API version `7.6-preview.1`. - Added new methods `StartPreRestoreAsync`, `StartPreRestore`, `StartPreBackupAsync`, and `StartPreBackupAsync` to the `KeyVaultBackupClient`. +- Support for Continuous Access Evaluation (CAE). ### Breaking Changes diff --git a/sdk/keyvault/Azure.Security.KeyVault.Administration/tests/ContinuousAccessEvaluationTests.cs b/sdk/keyvault/Azure.Security.KeyVault.Administration/tests/ContinuousAccessEvaluationTests.cs new file mode 100644 index 0000000000000..e086160b485db --- /dev/null +++ b/sdk/keyvault/Azure.Security.KeyVault.Administration/tests/ContinuousAccessEvaluationTests.cs @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using Azure.Security.KeyVault.Tests; +using NUnit.Framework; + +namespace Azure.Security.KeyVault.Administration.Tests +{ + [NonParallelizable] + internal class ContinuousAccessEvaluationTests : ContinuousAccessEvaluationTestsBase + { + [SetUp] + public void Setup() + { + ChallengeBasedAuthenticationPolicy.ClearCache(); + } + + [Test] + [TestCase(@"Bearer realm="""", authorization_uri=""https://login.microsoftonline.com/common/oauth2/authorize"", error=""insufficient_claims"", claims=""eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==""", """{"access_token":{"nbf":{"essential":true,"value":"1726077595"},"xms_caeerror":{"value":"10012"}}}""")] + public async Task VerifyCaeClaims(string challenge, string expectedClaims) + { + int callCount = 0; + + MockResponse response = new MockResponse(200); + + MockTransport transport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 1, final200response: response); + + var credential = new TokenCredentialStub((r, c) => + { + if (callCount == 0) + { + // The first challenge should not have any claims. + Assert.IsNull(r.Claims); + } + else if (callCount == 1) + { + Assert.AreEqual(expectedClaims, r.Claims); + } + else + { + Assert.Fail("unexpected token request"); + } + Interlocked.Increment(ref callCount); + Assert.AreEqual(true, r.IsCaeEnabled); + + return new(callCount.ToString(), DateTimeOffset.Now.AddHours(2)); + }, true); + + KeyVaultBackupClient client = new( + VaultUri, + credential, + new KeyVaultAdministrationClientOptions() + { + Transport = transport, + }); + + try + { + KeyVaultBackupOperation operation = await client.StartBackupAsync(VaultUri); + } + catch (RequestFailedException ex) + { + Assert.AreEqual(200, ex.Status); + return; + } + catch (Exception ex) + { + Assert.Fail($"Expected RequestFailedException, but got {ex.GetType()}"); + return; + } + } + + [Test] + public void ThrowsWithTwoConsecutiveCaeChallenges() + { + MockTransport keyVaultTransport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 2); + + MockTransport credentialTransport = GetMockCredentialTransport(2); + + KeyVaultBackupClient client = new( + VaultUri, + new MockCredential(credentialTransport), + new KeyVaultAdministrationClientOptions() + { + Transport = keyVaultTransport, + }); + + try + { + var operation = client.StartBackup(VaultUri); + } + catch (RequestFailedException ex) + { + Assert.AreEqual(401, ex.Status); + return; + } + catch (Exception ex) + { + Assert.Fail($"Expected RequestFailedException, but got {ex.GetType()}"); + return; + } + Assert.Fail("Expected RequestFailedException, but no exception was thrown."); + } + } +} diff --git a/sdk/keyvault/Azure.Security.KeyVault.Certificates/CHANGELOG.md b/sdk/keyvault/Azure.Security.KeyVault.Certificates/CHANGELOG.md index 917ca4d22e1da..35ad51f55b5a5 100644 --- a/sdk/keyvault/Azure.Security.KeyVault.Certificates/CHANGELOG.md +++ b/sdk/keyvault/Azure.Security.KeyVault.Certificates/CHANGELOG.md @@ -3,6 +3,7 @@ ## 4.7.0-beta.1 (Unreleased) ### Features Added +- Support for Continuous Access Evaluation (CAE). ### Breaking Changes diff --git a/sdk/keyvault/Azure.Security.KeyVault.Certificates/tests/ContinuousAccessEvaluationTests.cs b/sdk/keyvault/Azure.Security.KeyVault.Certificates/tests/ContinuousAccessEvaluationTests.cs new file mode 100644 index 0000000000000..27fca350ee8db --- /dev/null +++ b/sdk/keyvault/Azure.Security.KeyVault.Certificates/tests/ContinuousAccessEvaluationTests.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using Azure.Security.KeyVault.Tests; +using NUnit.Framework; + +namespace Azure.Security.KeyVault.Certificates.Tests +{ + [NonParallelizable] + internal class ContinuousAccessEvaluationTests : ContinuousAccessEvaluationTestsBase + { + [SetUp] + public void Setup() + { + ChallengeBasedAuthenticationPolicy.ClearCache(); + } + + [Test] + [TestCase(@"Bearer realm="""", authorization_uri=""https://login.microsoftonline.com/common/oauth2/authorize"", error=""insufficient_claims"", claims=""eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==""", """{"access_token":{"nbf":{"essential":true,"value":"1726077595"},"xms_caeerror":{"value":"10012"}}}""")] + public async Task VerifyCaeClaims(string challenge, string expectedClaims) + { + int callCount = 0; + + MockResponse responseWithSecret = new MockResponse(200) + .WithContent(@"{ + ""id"": ""https://foo.vault.azure.net/certificates/1/foo"", + ""cer"": ""Zm9v"", + ""attributes"": { + }, + ""pending"": { + ""id"": ""foo"" + } + }"); + + MockTransport transport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 1, final200response: responseWithSecret); + + var credential = new TokenCredentialStub((r, c) => + { + if (callCount == 0) + { + // The first challenge should not have any claims. + Assert.IsNull(r.Claims); + } + else if (callCount == 1) + { + Assert.AreEqual(expectedClaims, r.Claims); + } + else + { + Assert.Fail("unexpected token request"); + } + Interlocked.Increment(ref callCount); + Assert.AreEqual(true, r.IsCaeEnabled); + + return new(callCount.ToString(), DateTimeOffset.Now.AddHours(2)); + }, true); + + CertificateClient client = new( + VaultUri, + credential, + new CertificateClientOptions() + { + Transport = transport, + }); + + Response response = await client.GetCertificateAsync("certificate"); + Assert.AreEqual(200, response.GetRawResponse().Status); + } + + [Test] + public void ThrowsWithTwoConsecutiveCaeChallenges() + { + MockTransport keyVaultTransport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 2); + + MockTransport credentialTransport = GetMockCredentialTransport(2); + + CertificateClient client = new( + VaultUri, + new MockCredential(credentialTransport), + new CertificateClientOptions() + { + Transport = keyVaultTransport, + }); + + try + { + client.GetCertificate("certificate"); + } + catch (RequestFailedException ex) + { + Assert.AreEqual(401, ex.Status); + return; + } + catch (Exception ex) + { + Assert.Fail($"Expected RequestFailedException, but got {ex.GetType()}"); + return; + } + Assert.Fail("Expected RequestFailedException, but no exception was thrown."); + } + } +} diff --git a/sdk/keyvault/Azure.Security.KeyVault.Keys/CHANGELOG.md b/sdk/keyvault/Azure.Security.KeyVault.Keys/CHANGELOG.md index 1734a9546135d..1b20dcc967a67 100644 --- a/sdk/keyvault/Azure.Security.KeyVault.Keys/CHANGELOG.md +++ b/sdk/keyvault/Azure.Security.KeyVault.Keys/CHANGELOG.md @@ -3,6 +3,7 @@ ## 4.7.0-beta.1 (Unreleased) ### Features Added +- Support for Continuous Access Evaluation (CAE). ### Breaking Changes diff --git a/sdk/keyvault/Azure.Security.KeyVault.Keys/tests/ContinuousAccessEvaluationTests.cs b/sdk/keyvault/Azure.Security.KeyVault.Keys/tests/ContinuousAccessEvaluationTests.cs new file mode 100644 index 0000000000000..6f227b954781c --- /dev/null +++ b/sdk/keyvault/Azure.Security.KeyVault.Keys/tests/ContinuousAccessEvaluationTests.cs @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using Azure.Security.KeyVault.Tests; +using NUnit.Framework; + +namespace Azure.Security.KeyVault.Keys.Tests +{ + [NonParallelizable] + internal class ContinuousAccessEvaluationTests : ContinuousAccessEvaluationTestsBase + { + [SetUp] + public void Setup() + { + ChallengeBasedAuthenticationPolicy.ClearCache(); + } + + [Test] + [TestCase(@"Bearer realm="""", authorization_uri=""https://login.microsoftonline.com/common/oauth2/authorize"", error=""insufficient_claims"", claims=""eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==""", """{"access_token":{"nbf":{"essential":true,"value":"1726077595"},"xms_caeerror":{"value":"10012"}}}""")] + public async Task VerifyCaeClaims(string challenge, string expectedClaims) + { + int callCount = 0; + + MockResponse responseWithKey = new MockResponse(200) + .WithContent(@"{ + ""key"": { + ""kid"": ""https://heathskeyvault.vault.azure.net/keys/625710934/ef3685592e1c4e839206aaa10f0f058e"", + ""kty"": ""RSA"", + ""key_ops"": [ + ""encrypt"", + ""decrypt"", + ""sign"", + ""verify"", + ""wrapKey"", + ""unwrapKey"" + ], + ""n"": ""foo"", + ""e"": ""AQAB"" + }, + ""attributes"": { + ""enabled"": true, + ""created"": 1613807137, + ""updated"": 1613807137, + ""recoveryLevel"": ""Recoverable\u002BPurgeable"", + ""recoverableDays"": 90 + } + }"); + + MockTransport transport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 1, final200response: responseWithKey); + + var credential = new TokenCredentialStub((r, c) => + { + if (callCount == 0) + { + // The first challenge should not have any claims. + Assert.IsNull(r.Claims); + } + else if (callCount == 1) + { + Assert.AreEqual(expectedClaims, r.Claims); + } + else + { + Assert.Fail("unexpected token request"); + } + Interlocked.Increment(ref callCount); + Assert.AreEqual(true, r.IsCaeEnabled); + + return new(callCount.ToString(), DateTimeOffset.Now.AddHours(2)); + }, true); + + KeyClient client = new( + VaultUri, + credential, + new KeyClientOptions() + { + Transport = transport, + }); + + Response response = await client.GetKeyAsync("key"); + Assert.AreEqual(200, response.GetRawResponse().Status); + } + + [Test] + public void ThrowsWithTwoConsecutiveCaeChallenges() + { + MockTransport keyVaultTransport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 2); + + MockTransport credentialTransport = GetMockCredentialTransport(2); + + KeyClient client = new( + VaultUri, + new MockCredential(credentialTransport), + new KeyClientOptions() + { + Transport = keyVaultTransport, + }); + + try + { + client.GetKey("key"); + } + catch (RequestFailedException ex) + { + Assert.AreEqual(401, ex.Status); + return; + } + catch (Exception ex) + { + Assert.Fail($"Expected RequestFailedException, but got {ex.GetType()}"); + return; + } + Assert.Fail("Expected RequestFailedException, but no exception was thrown."); + } + } +} diff --git a/sdk/keyvault/Azure.Security.KeyVault.Secrets/CHANGELOG.md b/sdk/keyvault/Azure.Security.KeyVault.Secrets/CHANGELOG.md index 6e2354363b267..07b929203c42c 100644 --- a/sdk/keyvault/Azure.Security.KeyVault.Secrets/CHANGELOG.md +++ b/sdk/keyvault/Azure.Security.KeyVault.Secrets/CHANGELOG.md @@ -3,6 +3,7 @@ ## 4.7.0-beta.1 (Unreleased) ### Features Added +- Support for Continuous Access Evaluation (CAE). ### Breaking Changes diff --git a/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ChallengeBasedAuthenticationPolicyTests.cs b/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ChallengeBasedAuthenticationPolicyTests.cs index 50c53bc6aed8f..7330a6ad028ae 100644 --- a/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ChallengeBasedAuthenticationPolicyTests.cs +++ b/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ChallengeBasedAuthenticationPolicyTests.cs @@ -199,6 +199,18 @@ public async Task ReauthenticatesWhenTenantChanged() Assert.AreEqual("secret-value", response.Value.Value); } + [Test] + public void GetClaimsFromChallengeHeaders() + { + MockResponse response401WithClaims = new MockResponse(401) + .WithHeader("WWW-Authenticate", @"Bearer realm="""", authorization_uri=""https://login.microsoftonline.com/common/oauth2/authorize"", error=""insufficient_claims"", claims=""eyJhY2Nlc3NfdG9rZW4iOnsiYWNycyI6eyJlc3NlbnRpYWwiOnRydWUsInZhbHVlIjoiY3AxIn19fQ=="""); + Assert.AreEqual(ChallengeBasedAuthenticationPolicy.getDecodedClaimsParameter("insufficient_claims", response401WithClaims), @"{""access_token"":{""acrs"":{""essential"":true,""value"":""cp1""}}}"); + + MockResponse response401 = new MockResponse(401) + .WithHeader("WWW-Authenticate", @"Bearer authorization=""https://login.windows.net/de763a21-49f7-4b08-a8e1-52c8fbc103b4"", resource=""https://vault.azure.net"""); + Assert.IsNull(ChallengeBasedAuthenticationPolicy.getDecodedClaimsParameter(null, response401)); + } + private class MockTransportBuilder { private const string AuthorizationHeader = "Authorization"; diff --git a/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ContinousAccessEvaluationTests.cs b/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ContinousAccessEvaluationTests.cs new file mode 100644 index 0000000000000..dad44a925d3e0 --- /dev/null +++ b/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ContinousAccessEvaluationTests.cs @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using Azure.Security.KeyVault.Tests; +using NUnit.Framework; + +namespace Azure.Security.KeyVault.Secrets.Tests +{ + [NonParallelizable] + internal class ContinousAccessEvaluationTests : ContinuousAccessEvaluationTestsBase + { + [SetUp] + public void Setup() + { + ChallengeBasedAuthenticationPolicy.ClearCache(); + } + + [Test] + [TestCase(@"Bearer realm="""", authorization_uri=""https://login.microsoftonline.com/common/oauth2/authorize"", error=""insufficient_claims"", claims=""eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==""", """{"access_token":{"nbf":{"essential":true,"value":"1726077595"},"xms_caeerror":{"value":"10012"}}}""")] + public async Task VerifyCaeClaims(string challenge, string expectedClaims) + { + int callCount = 0; + + MockResponse responseWithSecret = new MockResponse(200) + { + ContentStream = new KeyVaultSecret("test-secret", "secret-value").ToStream(), + }; + + MockTransport transport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 1, final200response: responseWithSecret); + + var credential = new TokenCredentialStub((r, c) => + { + if (callCount == 0) + { + // The first challenge should not have any claims. + Assert.IsNull(r.Claims); + } + else if (callCount == 1) + { + Assert.AreEqual(expectedClaims, r.Claims); + } + else + { + Assert.Fail("unexpected token request"); + } + Interlocked.Increment(ref callCount); + Assert.AreEqual(true, r.IsCaeEnabled); + + return new(callCount.ToString(), DateTimeOffset.Now.AddHours(2)); + }, true); + + SecretClient client = new( + VaultUri, + credential, + new SecretClientOptions() + { + Transport = transport, + }); + + Response response = await client.GetSecretAsync("test-secret"); + Assert.AreEqual(200, response.GetRawResponse().Status); + Assert.AreEqual("secret-value", response.Value.Value); + } + + [Test] + public void ThrowsWithTwoConsecutiveCaeChallenges() + { + MockTransport keyVaultTransport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 2); + + MockTransport credentialTransport = GetMockCredentialTransport(2); + + SecretClient client = new( + VaultUri, + new MockCredential(credentialTransport), + new SecretClientOptions() + { + Transport = keyVaultTransport, + }); + + try + { + client.GetSecret("test-secret"); + } + catch (RequestFailedException ex) + { + Assert.AreEqual(401, ex.Status); + return; + } + catch (Exception ex) + { + Assert.Fail($"Expected RequestFailedException, but got {ex.GetType()}"); + return; + } + Assert.Fail("Expected RequestFailedException, but no exception was thrown."); + } + + [Test] + public void ensureTokenFromClaimsChallengeGetsUsed() + { + MockTransport transport = new(new[] + { + defaultInitialChallenge, + + new MockResponse(200) + .WithJson(""" + { + "token_type": "Bearer", + "expires_in": 3599, + "resource": "https://vault.azure.net", + "access_token": "TOKEN_1" + } + """), + + new MockResponse(200) + { + ContentStream = new KeyVaultSecret("test-secret", "secret-value").ToStream(), + }, + + defaultCaeChallenge, + + new MockResponse(200) + .WithJson(""" + { + "token_type": "Bearer", + "expires_in": 3599, + "resource": "https://vault.azure.net", + "access_token": "TOKEN_2" + } + """), + + new MockResponse(200) + { + ContentStream = new KeyVaultSecret("test-secret2", "secret-value").ToStream(), + }, + }); + + SecretClient client = new( + VaultUri, + new MockCredential(transport), + new SecretClientOptions() + { + Transport = transport, + }); + + _ = client.GetSecret("test-secret"); + _ = client.GetSecret("test-secret2"); + + Assert.IsTrue(transport.Requests[2].Headers.TryGetValue("Authorization", out string authorizationValue)); + Assert.AreEqual("Bearer TOKEN_1", authorizationValue); + + Assert.IsTrue(transport.Requests[5].Headers.TryGetValue("Authorization", out string authorizationValue2)); + Assert.AreEqual("Bearer TOKEN_2", authorizationValue2); + } + } +} diff --git a/sdk/keyvault/Azure.Security.KeyVault.Shared/src/ChallengeBasedAuthenticationPolicy.cs b/sdk/keyvault/Azure.Security.KeyVault.Shared/src/ChallengeBasedAuthenticationPolicy.cs index ec82ee2c1d630..0c52c7ba688fb 100644 --- a/sdk/keyvault/Azure.Security.KeyVault.Shared/src/ChallengeBasedAuthenticationPolicy.cs +++ b/sdk/keyvault/Azure.Security.KeyVault.Shared/src/ChallengeBasedAuthenticationPolicy.cs @@ -6,6 +6,7 @@ using System; using System.Collections.Concurrent; using System.Globalization; +using System.Net; using System.Threading.Tasks; namespace Azure.Security.KeyVault @@ -51,7 +52,7 @@ private async ValueTask AuthorizeRequestInternal(HttpMessage message, bool async if (_challenge != null) { // We fetched the challenge from the cache, but we have not initialized the Scopes in the base yet. - var context = new TokenRequestContext(_challenge.Scopes, parentRequestId: message.Request.ClientRequestId, tenantId: _challenge.TenantId); + var context = new TokenRequestContext(_challenge.Scopes, parentRequestId: message.Request.ClientRequestId, tenantId: _challenge.TenantId, isCaeEnabled: true); if (async) { await AuthenticateAndAuthorizeRequestAsync(message, context).ConfigureAwait(false); @@ -84,6 +85,29 @@ protected override ValueTask AuthorizeRequestOnChallengeAsync(HttpMessage protected override bool AuthorizeRequestOnChallenge(HttpMessage message) => AuthorizeRequestOnChallengeAsyncInternal(message, false).EnsureCompleted(); + /// + /// Gets the claims parameter from the challenge response. + /// If there are no claims, returns null. + /// + /// The error message from the service. + /// The response from the service which contains the headers. + /// A string with the decoded claims if present, otherwise null + internal static string getDecodedClaimsParameter(string error, Response response) + { + // According to docs https://learn.microsoft.com/en-us/entra/identity-platform/claims-challenge?tabs=dotnet#claims-challenge-header-format, + // the error message must be "insufficient_claims" when a claims challenge should be generated. + if (error == "insufficient_claims") + { + return AuthorizationChallengeParser.GetChallengeParameterFromResponse(response, "Bearer", "claims") switch + { + { Length: 0 } => null, + string enc => System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(enc)) + }; + } + + return null; + } + private async ValueTask AuthorizeRequestOnChallengeAsyncInternal(HttpMessage message, bool async) { if (message.Request.Content == null && message.TryGetProperty(KeyVaultStashedContentKey, out var content)) @@ -91,8 +115,10 @@ private async ValueTask AuthorizeRequestOnChallengeAsyncInternal(HttpMessa message.Request.Content = content as RequestContent; } + string error = AuthorizationChallengeParser.GetChallengeParameterFromResponse(message.Response, "Bearer", "error"); string authority = GetRequestAuthority(message.Request); string scope = AuthorizationChallengeParser.GetChallengeParameterFromResponse(message.Response, "Bearer", "resource"); + if (scope != null) { scope += "/.default"; @@ -102,6 +128,15 @@ private async ValueTask AuthorizeRequestOnChallengeAsyncInternal(HttpMessa scope = AuthorizationChallengeParser.GetChallengeParameterFromResponse(message.Response, "Bearer", "scope"); } + // Handle CAE Challenges + string claims = getDecodedClaimsParameter(error, message.Response); + if (claims != null) + { + // Get the scope from the cache + s_challengeCache.TryGetValue(authority, out _challenge); + scope = _challenge.Scopes[0]; + } + if (scope is null) { if (s_challengeCache.TryGetValue(authority, out _challenge)) @@ -140,7 +175,7 @@ private async ValueTask AuthorizeRequestOnChallengeAsyncInternal(HttpMessa s_challengeCache[authority] = _challenge; } - var context = new TokenRequestContext(_challenge.Scopes, parentRequestId: message.Request.ClientRequestId, tenantId: _challenge.TenantId); + var context = new TokenRequestContext(_challenge.Scopes, parentRequestId: message.Request.ClientRequestId, tenantId: _challenge.TenantId, isCaeEnabled: true, claims: claims); if (async) { await AuthenticateAndAuthorizeRequestAsync(message, context).ConfigureAwait(false); @@ -153,6 +188,81 @@ private async ValueTask AuthorizeRequestOnChallengeAsyncInternal(HttpMessa return true; } + /// + public override ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline) + { + return ProcessAsyncInternal(message, pipeline, true); + } + + /// + public override void Process(HttpMessage message, ReadOnlyMemory pipeline) + { + ProcessAsyncInternal(message, pipeline, false).EnsureCompleted(); + } + + private async ValueTask ProcessAsyncInternal(HttpMessage message, ReadOnlyMemory pipeline, bool async) + { + if (message.Request.Uri.Scheme != Uri.UriSchemeHttps) + { + throw new InvalidOperationException("Bearer token authentication is not permitted for non TLS protected (https) endpoints."); + } + + if (async) + { + await AuthorizeRequestAsync(message).ConfigureAwait(false); + await ProcessNextAsync(message, pipeline).ConfigureAwait(false); + } + else + { + AuthorizeRequest(message); + ProcessNext(message, pipeline); + } + + // Check if we have received a challenge or we have not yet issued the first request. + if (message.Response.Status == (int)HttpStatusCode.Unauthorized && message.Response.Headers.Contains(HttpHeader.Names.WwwAuthenticate)) + { + // Attempt to get the TokenRequestContext based on the challenge. + // If we fail to get the context, the challenge was not present or invalid. + // If we succeed in getting the context, authenticate the request and pass it up the policy chain. + if (async) + { + if (await AuthorizeRequestOnChallengeAsync(message).ConfigureAwait(false)) + { + await ProcessNextAsync(message, pipeline).ConfigureAwait(false); + } + } + else + { + if (AuthorizeRequestOnChallenge(message)) + { + ProcessNext(message, pipeline); + } + } + + // Handle the scenario in which we get a CAE challenge back. + if (message.Response.Status == (int)HttpStatusCode.Unauthorized + && message.Response.Headers.Contains(HttpHeader.Names.WwwAuthenticate) + && AuthorizationChallengeParser.GetChallengeParameterFromResponse(message.Response, "Bearer", "claims") != null) + { + if (async) + { + if (await AuthorizeRequestOnChallengeAsync(message).ConfigureAwait(false)) + { + await ProcessNextAsync(message, pipeline).ConfigureAwait(false); + } + } + else + { + if (AuthorizeRequestOnChallenge(message)) + { + ProcessNext(message, pipeline); + } + } + } + // If we get a second CAE challenge, an unlikely scenario, we do not attempt to re-authenticate. + } + } + internal class ChallengeParameters { internal ChallengeParameters(Uri authorizationUri, string[] scopes) diff --git a/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/Azure.Security.KeyVault.Shared.Tests.projitems b/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/Azure.Security.KeyVault.Shared.Tests.projitems index 9f553875f304a..55413c17628b9 100644 --- a/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/Azure.Security.KeyVault.Shared.Tests.projitems +++ b/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/Azure.Security.KeyVault.Shared.Tests.projitems @@ -9,6 +9,7 @@ Azure.Security.KeyVault.Tests + diff --git a/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/ContinuousAccessEvaluationTestsBase.cs b/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/ContinuousAccessEvaluationTestsBase.cs new file mode 100644 index 0000000000000..af2ee53ff9139 --- /dev/null +++ b/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/ContinuousAccessEvaluationTestsBase.cs @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; +using System.Threading; +using Azure.Core.TestFramework; +using Azure.Core; +using System.IO; +using System.Text.Json; +using Azure.Core.Pipeline; + +namespace Azure.Security.KeyVault.Tests +{ + internal class ContinuousAccessEvaluationTestsBase + { + protected MockResponse defaultCaeChallenge = new MockResponse(401).WithHeader("WWW-Authenticate", @"Bearer realm="""", authorization_uri=""https://login.microsoftonline.com/common/oauth2/authorize"", error=""insufficient_claims"", claims=""eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="""); + + protected MockResponse defaultInitialChallenge = new MockResponse(401).WithHeader("WWW-Authenticate", @"Bearer authorization=""https://login.windows.net/de763a21-49f7-4b08-a8e1-52c8fbc103b4"", resource=""https://vault.azure.net"""); + + private const string VaultHost = "test.vault.azure.net"; + protected Uri VaultUri => new Uri("https://" + VaultHost); + + protected MockTransport GetMockTransportWithCaeChallenges(int numberOfCaeChallenges = 1, MockResponse final200response = null ) + { + if (numberOfCaeChallenges < 1) + { + throw new ArgumentOutOfRangeException(nameof(numberOfCaeChallenges), "Number of CAE challenges must be greater than or equal to 1."); + } + + var responses = new List { defaultInitialChallenge }; + for (int i = 0; i < numberOfCaeChallenges; i++) + { + responses.Add(defaultCaeChallenge); + } + if (final200response != null) + { + responses.Add(final200response); + } + return new MockTransport(responses.ToArray()); + } + + protected MockTransport GetMockCredentialTransport(int numberOfTokenResponses = 1) + { + if (numberOfTokenResponses < 1) + { + throw new ArgumentOutOfRangeException(nameof(numberOfTokenResponses), "Number of token responses must be greater than or equal to 1."); + } + + var responses = new List(); + for (int i = 0; i < numberOfTokenResponses; i++) + { + responses.Add(new MockResponse(200) + .WithJson(""" + { + "token_type": "Bearer", + "expires_in": 3599, + "resource": "https://vault.azure.net", + "access_token": "foo" + } + """)); + } + return new MockTransport(responses.ToArray()); + } + + protected class TokenCredentialStub : TokenCredential + { + public TokenCredentialStub() { } + + public TokenCredentialStub(Func handler, bool isAsync) + { + setCallBack(handler, isAsync); + } + + private Func> _getTokenAsyncHandler; + private Func _getTokenHandler; + + public void setCallBack(Func handler, bool isAsync) + { + if (isAsync) + { +#pragma warning disable 1998 + _getTokenAsyncHandler = async (r, c) => handler(r, c); +#pragma warning restore 1998 + } + else + { + _getTokenHandler = handler; + } + } + + public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) + => _getTokenAsyncHandler(requestContext, cancellationToken); + + public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) + => _getTokenHandler(requestContext, cancellationToken); + } + + protected class MockCredential : TokenCredential + { + private readonly HttpPipeline _pipeline; + private readonly string _tenantId; + private readonly string _clientId; + private readonly string _clientSecret; + private const string TenantId = "72f988bf-86f1-41af-91ab-2d7cd011db47"; + + public MockCredential(MockTransport transport, string tenantId = TenantId, string clientId = "test_id", string clientSecret = "test_secret") + { + _pipeline = new HttpPipeline(transport); + _tenantId = tenantId ?? throw new ArgumentNullException(nameof(tenantId)); + _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); + _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret)); + } + + public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) => GetTokenAsync(requestContext, cancellationToken).EnsureCompleted(); + + public override async ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + Request request = _pipeline.CreateRequest(); + request.Method = RequestMethod.Post; + request.Headers.Add(HttpHeader.Common.FormUrlEncodedContentType); + + request.Uri.Reset(new Uri($"https://login.windows.net/{_tenantId}/oauth2/v2.0/token")); + + string body = $"response_type=token&grant_type=client_credentials&client_id={Uri.EscapeDataString(_clientId)}&client_secret={Uri.EscapeDataString(_clientSecret)}&scope={Uri.EscapeDataString(string.Join(" ", requestContext.Scopes))}"; + ReadOnlyMemory content = Encoding.UTF8.GetBytes(body).AsMemory(); + request.Content = RequestContent.Create(content); + + Response response = await _pipeline.SendRequestAsync(request, cancellationToken); + if (response.Status == 200 || response.Status == 201) + { + return await DeserializeAsync(response.ContentStream, cancellationToken); + } + + throw new RequestFailedException(response.Status, response.ReasonPhrase); + } + + private static async Task DeserializeAsync(Stream content, CancellationToken cancellationToken) + { + using (JsonDocument json = await JsonDocument.ParseAsync(content, default, cancellationToken).ConfigureAwait(false)) + { + return Deserialize(json.RootElement); + } + } + + private static AccessToken Deserialize(JsonElement json) + { + string accessToken = null; + DateTimeOffset expiresOn = DateTimeOffset.MaxValue; + + foreach (JsonProperty prop in json.EnumerateObject()) + { + switch (prop.Name) + { + case "access_token": + accessToken = prop.Value.GetString(); + break; + + case "expires_in": + expiresOn = DateTimeOffset.UtcNow + TimeSpan.FromSeconds(prop.Value.GetInt64()); + break; + } + } + + return new AccessToken(accessToken, expiresOn); + } + } + } +}