diff --git a/eng/Packages.Data.props b/eng/Packages.Data.props
index 90250c5d9ca93..c69c48779f5b3 100644
--- a/eng/Packages.Data.props
+++ b/eng/Packages.Data.props
@@ -354,7 +354,7 @@
-
+
diff --git a/sdk/keyvault/Azure.Security.KeyVault.Administration/CHANGELOG.md b/sdk/keyvault/Azure.Security.KeyVault.Administration/CHANGELOG.md
index 4b08ca6035e40..b917429847d20 100644
--- a/sdk/keyvault/Azure.Security.KeyVault.Administration/CHANGELOG.md
+++ b/sdk/keyvault/Azure.Security.KeyVault.Administration/CHANGELOG.md
@@ -6,6 +6,7 @@
- Added support for service API version `7.6-preview.1`.
- Added new methods `StartPreRestoreAsync`, `StartPreRestore`, `StartPreBackupAsync`, and `StartPreBackupAsync` to the `KeyVaultBackupClient`.
+- Support for Continuous Access Evaluation (CAE).
### Breaking Changes
diff --git a/sdk/keyvault/Azure.Security.KeyVault.Administration/tests/ContinuousAccessEvaluationTests.cs b/sdk/keyvault/Azure.Security.KeyVault.Administration/tests/ContinuousAccessEvaluationTests.cs
new file mode 100644
index 0000000000000..e086160b485db
--- /dev/null
+++ b/sdk/keyvault/Azure.Security.KeyVault.Administration/tests/ContinuousAccessEvaluationTests.cs
@@ -0,0 +1,109 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.Core.TestFramework;
+using Azure.Security.KeyVault.Tests;
+using NUnit.Framework;
+
+namespace Azure.Security.KeyVault.Administration.Tests
+{
+ [NonParallelizable]
+ internal class ContinuousAccessEvaluationTests : ContinuousAccessEvaluationTestsBase
+ {
+ [SetUp]
+ public void Setup()
+ {
+ ChallengeBasedAuthenticationPolicy.ClearCache();
+ }
+
+ [Test]
+ [TestCase(@"Bearer realm="""", authorization_uri=""https://login.microsoftonline.com/common/oauth2/authorize"", error=""insufficient_claims"", claims=""eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==""", """{"access_token":{"nbf":{"essential":true,"value":"1726077595"},"xms_caeerror":{"value":"10012"}}}""")]
+ public async Task VerifyCaeClaims(string challenge, string expectedClaims)
+ {
+ int callCount = 0;
+
+ MockResponse response = new MockResponse(200);
+
+ MockTransport transport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 1, final200response: response);
+
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ if (callCount == 0)
+ {
+ // The first challenge should not have any claims.
+ Assert.IsNull(r.Claims);
+ }
+ else if (callCount == 1)
+ {
+ Assert.AreEqual(expectedClaims, r.Claims);
+ }
+ else
+ {
+ Assert.Fail("unexpected token request");
+ }
+ Interlocked.Increment(ref callCount);
+ Assert.AreEqual(true, r.IsCaeEnabled);
+
+ return new(callCount.ToString(), DateTimeOffset.Now.AddHours(2));
+ }, true);
+
+ KeyVaultBackupClient client = new(
+ VaultUri,
+ credential,
+ new KeyVaultAdministrationClientOptions()
+ {
+ Transport = transport,
+ });
+
+ try
+ {
+ KeyVaultBackupOperation operation = await client.StartBackupAsync(VaultUri);
+ }
+ catch (RequestFailedException ex)
+ {
+ Assert.AreEqual(200, ex.Status);
+ return;
+ }
+ catch (Exception ex)
+ {
+ Assert.Fail($"Expected RequestFailedException, but got {ex.GetType()}");
+ return;
+ }
+ }
+
+ [Test]
+ public void ThrowsWithTwoConsecutiveCaeChallenges()
+ {
+ MockTransport keyVaultTransport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 2);
+
+ MockTransport credentialTransport = GetMockCredentialTransport(2);
+
+ KeyVaultBackupClient client = new(
+ VaultUri,
+ new MockCredential(credentialTransport),
+ new KeyVaultAdministrationClientOptions()
+ {
+ Transport = keyVaultTransport,
+ });
+
+ try
+ {
+ var operation = client.StartBackup(VaultUri);
+ }
+ catch (RequestFailedException ex)
+ {
+ Assert.AreEqual(401, ex.Status);
+ return;
+ }
+ catch (Exception ex)
+ {
+ Assert.Fail($"Expected RequestFailedException, but got {ex.GetType()}");
+ return;
+ }
+ Assert.Fail("Expected RequestFailedException, but no exception was thrown.");
+ }
+ }
+}
diff --git a/sdk/keyvault/Azure.Security.KeyVault.Certificates/CHANGELOG.md b/sdk/keyvault/Azure.Security.KeyVault.Certificates/CHANGELOG.md
index 917ca4d22e1da..35ad51f55b5a5 100644
--- a/sdk/keyvault/Azure.Security.KeyVault.Certificates/CHANGELOG.md
+++ b/sdk/keyvault/Azure.Security.KeyVault.Certificates/CHANGELOG.md
@@ -3,6 +3,7 @@
## 4.7.0-beta.1 (Unreleased)
### Features Added
+- Support for Continuous Access Evaluation (CAE).
### Breaking Changes
diff --git a/sdk/keyvault/Azure.Security.KeyVault.Certificates/tests/ContinuousAccessEvaluationTests.cs b/sdk/keyvault/Azure.Security.KeyVault.Certificates/tests/ContinuousAccessEvaluationTests.cs
new file mode 100644
index 0000000000000..27fca350ee8db
--- /dev/null
+++ b/sdk/keyvault/Azure.Security.KeyVault.Certificates/tests/ContinuousAccessEvaluationTests.cs
@@ -0,0 +1,106 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.Core.TestFramework;
+using Azure.Security.KeyVault.Tests;
+using NUnit.Framework;
+
+namespace Azure.Security.KeyVault.Certificates.Tests
+{
+ [NonParallelizable]
+ internal class ContinuousAccessEvaluationTests : ContinuousAccessEvaluationTestsBase
+ {
+ [SetUp]
+ public void Setup()
+ {
+ ChallengeBasedAuthenticationPolicy.ClearCache();
+ }
+
+ [Test]
+ [TestCase(@"Bearer realm="""", authorization_uri=""https://login.microsoftonline.com/common/oauth2/authorize"", error=""insufficient_claims"", claims=""eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==""", """{"access_token":{"nbf":{"essential":true,"value":"1726077595"},"xms_caeerror":{"value":"10012"}}}""")]
+ public async Task VerifyCaeClaims(string challenge, string expectedClaims)
+ {
+ int callCount = 0;
+
+ MockResponse responseWithSecret = new MockResponse(200)
+ .WithContent(@"{
+ ""id"": ""https://foo.vault.azure.net/certificates/1/foo"",
+ ""cer"": ""Zm9v"",
+ ""attributes"": {
+ },
+ ""pending"": {
+ ""id"": ""foo""
+ }
+ }");
+
+ MockTransport transport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 1, final200response: responseWithSecret);
+
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ if (callCount == 0)
+ {
+ // The first challenge should not have any claims.
+ Assert.IsNull(r.Claims);
+ }
+ else if (callCount == 1)
+ {
+ Assert.AreEqual(expectedClaims, r.Claims);
+ }
+ else
+ {
+ Assert.Fail("unexpected token request");
+ }
+ Interlocked.Increment(ref callCount);
+ Assert.AreEqual(true, r.IsCaeEnabled);
+
+ return new(callCount.ToString(), DateTimeOffset.Now.AddHours(2));
+ }, true);
+
+ CertificateClient client = new(
+ VaultUri,
+ credential,
+ new CertificateClientOptions()
+ {
+ Transport = transport,
+ });
+
+ Response response = await client.GetCertificateAsync("certificate");
+ Assert.AreEqual(200, response.GetRawResponse().Status);
+ }
+
+ [Test]
+ public void ThrowsWithTwoConsecutiveCaeChallenges()
+ {
+ MockTransport keyVaultTransport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 2);
+
+ MockTransport credentialTransport = GetMockCredentialTransport(2);
+
+ CertificateClient client = new(
+ VaultUri,
+ new MockCredential(credentialTransport),
+ new CertificateClientOptions()
+ {
+ Transport = keyVaultTransport,
+ });
+
+ try
+ {
+ client.GetCertificate("certificate");
+ }
+ catch (RequestFailedException ex)
+ {
+ Assert.AreEqual(401, ex.Status);
+ return;
+ }
+ catch (Exception ex)
+ {
+ Assert.Fail($"Expected RequestFailedException, but got {ex.GetType()}");
+ return;
+ }
+ Assert.Fail("Expected RequestFailedException, but no exception was thrown.");
+ }
+ }
+}
diff --git a/sdk/keyvault/Azure.Security.KeyVault.Keys/CHANGELOG.md b/sdk/keyvault/Azure.Security.KeyVault.Keys/CHANGELOG.md
index 1734a9546135d..1b20dcc967a67 100644
--- a/sdk/keyvault/Azure.Security.KeyVault.Keys/CHANGELOG.md
+++ b/sdk/keyvault/Azure.Security.KeyVault.Keys/CHANGELOG.md
@@ -3,6 +3,7 @@
## 4.7.0-beta.1 (Unreleased)
### Features Added
+- Support for Continuous Access Evaluation (CAE).
### Breaking Changes
diff --git a/sdk/keyvault/Azure.Security.KeyVault.Keys/tests/ContinuousAccessEvaluationTests.cs b/sdk/keyvault/Azure.Security.KeyVault.Keys/tests/ContinuousAccessEvaluationTests.cs
new file mode 100644
index 0000000000000..6f227b954781c
--- /dev/null
+++ b/sdk/keyvault/Azure.Security.KeyVault.Keys/tests/ContinuousAccessEvaluationTests.cs
@@ -0,0 +1,120 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.Core.TestFramework;
+using Azure.Security.KeyVault.Tests;
+using NUnit.Framework;
+
+namespace Azure.Security.KeyVault.Keys.Tests
+{
+ [NonParallelizable]
+ internal class ContinuousAccessEvaluationTests : ContinuousAccessEvaluationTestsBase
+ {
+ [SetUp]
+ public void Setup()
+ {
+ ChallengeBasedAuthenticationPolicy.ClearCache();
+ }
+
+ [Test]
+ [TestCase(@"Bearer realm="""", authorization_uri=""https://login.microsoftonline.com/common/oauth2/authorize"", error=""insufficient_claims"", claims=""eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==""", """{"access_token":{"nbf":{"essential":true,"value":"1726077595"},"xms_caeerror":{"value":"10012"}}}""")]
+ public async Task VerifyCaeClaims(string challenge, string expectedClaims)
+ {
+ int callCount = 0;
+
+ MockResponse responseWithKey = new MockResponse(200)
+ .WithContent(@"{
+ ""key"": {
+ ""kid"": ""https://heathskeyvault.vault.azure.net/keys/625710934/ef3685592e1c4e839206aaa10f0f058e"",
+ ""kty"": ""RSA"",
+ ""key_ops"": [
+ ""encrypt"",
+ ""decrypt"",
+ ""sign"",
+ ""verify"",
+ ""wrapKey"",
+ ""unwrapKey""
+ ],
+ ""n"": ""foo"",
+ ""e"": ""AQAB""
+ },
+ ""attributes"": {
+ ""enabled"": true,
+ ""created"": 1613807137,
+ ""updated"": 1613807137,
+ ""recoveryLevel"": ""Recoverable\u002BPurgeable"",
+ ""recoverableDays"": 90
+ }
+ }");
+
+ MockTransport transport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 1, final200response: responseWithKey);
+
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ if (callCount == 0)
+ {
+ // The first challenge should not have any claims.
+ Assert.IsNull(r.Claims);
+ }
+ else if (callCount == 1)
+ {
+ Assert.AreEqual(expectedClaims, r.Claims);
+ }
+ else
+ {
+ Assert.Fail("unexpected token request");
+ }
+ Interlocked.Increment(ref callCount);
+ Assert.AreEqual(true, r.IsCaeEnabled);
+
+ return new(callCount.ToString(), DateTimeOffset.Now.AddHours(2));
+ }, true);
+
+ KeyClient client = new(
+ VaultUri,
+ credential,
+ new KeyClientOptions()
+ {
+ Transport = transport,
+ });
+
+ Response response = await client.GetKeyAsync("key");
+ Assert.AreEqual(200, response.GetRawResponse().Status);
+ }
+
+ [Test]
+ public void ThrowsWithTwoConsecutiveCaeChallenges()
+ {
+ MockTransport keyVaultTransport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 2);
+
+ MockTransport credentialTransport = GetMockCredentialTransport(2);
+
+ KeyClient client = new(
+ VaultUri,
+ new MockCredential(credentialTransport),
+ new KeyClientOptions()
+ {
+ Transport = keyVaultTransport,
+ });
+
+ try
+ {
+ client.GetKey("key");
+ }
+ catch (RequestFailedException ex)
+ {
+ Assert.AreEqual(401, ex.Status);
+ return;
+ }
+ catch (Exception ex)
+ {
+ Assert.Fail($"Expected RequestFailedException, but got {ex.GetType()}");
+ return;
+ }
+ Assert.Fail("Expected RequestFailedException, but no exception was thrown.");
+ }
+ }
+}
diff --git a/sdk/keyvault/Azure.Security.KeyVault.Secrets/CHANGELOG.md b/sdk/keyvault/Azure.Security.KeyVault.Secrets/CHANGELOG.md
index 6e2354363b267..07b929203c42c 100644
--- a/sdk/keyvault/Azure.Security.KeyVault.Secrets/CHANGELOG.md
+++ b/sdk/keyvault/Azure.Security.KeyVault.Secrets/CHANGELOG.md
@@ -3,6 +3,7 @@
## 4.7.0-beta.1 (Unreleased)
### Features Added
+- Support for Continuous Access Evaluation (CAE).
### Breaking Changes
diff --git a/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ChallengeBasedAuthenticationPolicyTests.cs b/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ChallengeBasedAuthenticationPolicyTests.cs
index 50c53bc6aed8f..7330a6ad028ae 100644
--- a/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ChallengeBasedAuthenticationPolicyTests.cs
+++ b/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ChallengeBasedAuthenticationPolicyTests.cs
@@ -199,6 +199,18 @@ public async Task ReauthenticatesWhenTenantChanged()
Assert.AreEqual("secret-value", response.Value.Value);
}
+ [Test]
+ public void GetClaimsFromChallengeHeaders()
+ {
+ MockResponse response401WithClaims = new MockResponse(401)
+ .WithHeader("WWW-Authenticate", @"Bearer realm="""", authorization_uri=""https://login.microsoftonline.com/common/oauth2/authorize"", error=""insufficient_claims"", claims=""eyJhY2Nlc3NfdG9rZW4iOnsiYWNycyI6eyJlc3NlbnRpYWwiOnRydWUsInZhbHVlIjoiY3AxIn19fQ==""");
+ Assert.AreEqual(ChallengeBasedAuthenticationPolicy.getDecodedClaimsParameter("insufficient_claims", response401WithClaims), @"{""access_token"":{""acrs"":{""essential"":true,""value"":""cp1""}}}");
+
+ MockResponse response401 = new MockResponse(401)
+ .WithHeader("WWW-Authenticate", @"Bearer authorization=""https://login.windows.net/de763a21-49f7-4b08-a8e1-52c8fbc103b4"", resource=""https://vault.azure.net""");
+ Assert.IsNull(ChallengeBasedAuthenticationPolicy.getDecodedClaimsParameter(null, response401));
+ }
+
private class MockTransportBuilder
{
private const string AuthorizationHeader = "Authorization";
diff --git a/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ContinousAccessEvaluationTests.cs b/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ContinousAccessEvaluationTests.cs
new file mode 100644
index 0000000000000..dad44a925d3e0
--- /dev/null
+++ b/sdk/keyvault/Azure.Security.KeyVault.Secrets/tests/ContinousAccessEvaluationTests.cs
@@ -0,0 +1,159 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.Core.TestFramework;
+using Azure.Security.KeyVault.Tests;
+using NUnit.Framework;
+
+namespace Azure.Security.KeyVault.Secrets.Tests
+{
+ [NonParallelizable]
+ internal class ContinousAccessEvaluationTests : ContinuousAccessEvaluationTestsBase
+ {
+ [SetUp]
+ public void Setup()
+ {
+ ChallengeBasedAuthenticationPolicy.ClearCache();
+ }
+
+ [Test]
+ [TestCase(@"Bearer realm="""", authorization_uri=""https://login.microsoftonline.com/common/oauth2/authorize"", error=""insufficient_claims"", claims=""eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==""", """{"access_token":{"nbf":{"essential":true,"value":"1726077595"},"xms_caeerror":{"value":"10012"}}}""")]
+ public async Task VerifyCaeClaims(string challenge, string expectedClaims)
+ {
+ int callCount = 0;
+
+ MockResponse responseWithSecret = new MockResponse(200)
+ {
+ ContentStream = new KeyVaultSecret("test-secret", "secret-value").ToStream(),
+ };
+
+ MockTransport transport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 1, final200response: responseWithSecret);
+
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ if (callCount == 0)
+ {
+ // The first challenge should not have any claims.
+ Assert.IsNull(r.Claims);
+ }
+ else if (callCount == 1)
+ {
+ Assert.AreEqual(expectedClaims, r.Claims);
+ }
+ else
+ {
+ Assert.Fail("unexpected token request");
+ }
+ Interlocked.Increment(ref callCount);
+ Assert.AreEqual(true, r.IsCaeEnabled);
+
+ return new(callCount.ToString(), DateTimeOffset.Now.AddHours(2));
+ }, true);
+
+ SecretClient client = new(
+ VaultUri,
+ credential,
+ new SecretClientOptions()
+ {
+ Transport = transport,
+ });
+
+ Response response = await client.GetSecretAsync("test-secret");
+ Assert.AreEqual(200, response.GetRawResponse().Status);
+ Assert.AreEqual("secret-value", response.Value.Value);
+ }
+
+ [Test]
+ public void ThrowsWithTwoConsecutiveCaeChallenges()
+ {
+ MockTransport keyVaultTransport = GetMockTransportWithCaeChallenges(numberOfCaeChallenges: 2);
+
+ MockTransport credentialTransport = GetMockCredentialTransport(2);
+
+ SecretClient client = new(
+ VaultUri,
+ new MockCredential(credentialTransport),
+ new SecretClientOptions()
+ {
+ Transport = keyVaultTransport,
+ });
+
+ try
+ {
+ client.GetSecret("test-secret");
+ }
+ catch (RequestFailedException ex)
+ {
+ Assert.AreEqual(401, ex.Status);
+ return;
+ }
+ catch (Exception ex)
+ {
+ Assert.Fail($"Expected RequestFailedException, but got {ex.GetType()}");
+ return;
+ }
+ Assert.Fail("Expected RequestFailedException, but no exception was thrown.");
+ }
+
+ [Test]
+ public void ensureTokenFromClaimsChallengeGetsUsed()
+ {
+ MockTransport transport = new(new[]
+ {
+ defaultInitialChallenge,
+
+ new MockResponse(200)
+ .WithJson("""
+ {
+ "token_type": "Bearer",
+ "expires_in": 3599,
+ "resource": "https://vault.azure.net",
+ "access_token": "TOKEN_1"
+ }
+ """),
+
+ new MockResponse(200)
+ {
+ ContentStream = new KeyVaultSecret("test-secret", "secret-value").ToStream(),
+ },
+
+ defaultCaeChallenge,
+
+ new MockResponse(200)
+ .WithJson("""
+ {
+ "token_type": "Bearer",
+ "expires_in": 3599,
+ "resource": "https://vault.azure.net",
+ "access_token": "TOKEN_2"
+ }
+ """),
+
+ new MockResponse(200)
+ {
+ ContentStream = new KeyVaultSecret("test-secret2", "secret-value").ToStream(),
+ },
+ });
+
+ SecretClient client = new(
+ VaultUri,
+ new MockCredential(transport),
+ new SecretClientOptions()
+ {
+ Transport = transport,
+ });
+
+ _ = client.GetSecret("test-secret");
+ _ = client.GetSecret("test-secret2");
+
+ Assert.IsTrue(transport.Requests[2].Headers.TryGetValue("Authorization", out string authorizationValue));
+ Assert.AreEqual("Bearer TOKEN_1", authorizationValue);
+
+ Assert.IsTrue(transport.Requests[5].Headers.TryGetValue("Authorization", out string authorizationValue2));
+ Assert.AreEqual("Bearer TOKEN_2", authorizationValue2);
+ }
+ }
+}
diff --git a/sdk/keyvault/Azure.Security.KeyVault.Shared/src/ChallengeBasedAuthenticationPolicy.cs b/sdk/keyvault/Azure.Security.KeyVault.Shared/src/ChallengeBasedAuthenticationPolicy.cs
index ec82ee2c1d630..0c52c7ba688fb 100644
--- a/sdk/keyvault/Azure.Security.KeyVault.Shared/src/ChallengeBasedAuthenticationPolicy.cs
+++ b/sdk/keyvault/Azure.Security.KeyVault.Shared/src/ChallengeBasedAuthenticationPolicy.cs
@@ -6,6 +6,7 @@
using System;
using System.Collections.Concurrent;
using System.Globalization;
+using System.Net;
using System.Threading.Tasks;
namespace Azure.Security.KeyVault
@@ -51,7 +52,7 @@ private async ValueTask AuthorizeRequestInternal(HttpMessage message, bool async
if (_challenge != null)
{
// We fetched the challenge from the cache, but we have not initialized the Scopes in the base yet.
- var context = new TokenRequestContext(_challenge.Scopes, parentRequestId: message.Request.ClientRequestId, tenantId: _challenge.TenantId);
+ var context = new TokenRequestContext(_challenge.Scopes, parentRequestId: message.Request.ClientRequestId, tenantId: _challenge.TenantId, isCaeEnabled: true);
if (async)
{
await AuthenticateAndAuthorizeRequestAsync(message, context).ConfigureAwait(false);
@@ -84,6 +85,29 @@ protected override ValueTask AuthorizeRequestOnChallengeAsync(HttpMessage
protected override bool AuthorizeRequestOnChallenge(HttpMessage message)
=> AuthorizeRequestOnChallengeAsyncInternal(message, false).EnsureCompleted();
+ ///
+ /// Gets the claims parameter from the challenge response.
+ /// If there are no claims, returns null.
+ ///
+ /// The error message from the service.
+ /// The response from the service which contains the headers.
+ /// A string with the decoded claims if present, otherwise null
+ internal static string getDecodedClaimsParameter(string error, Response response)
+ {
+ // According to docs https://learn.microsoft.com/en-us/entra/identity-platform/claims-challenge?tabs=dotnet#claims-challenge-header-format,
+ // the error message must be "insufficient_claims" when a claims challenge should be generated.
+ if (error == "insufficient_claims")
+ {
+ return AuthorizationChallengeParser.GetChallengeParameterFromResponse(response, "Bearer", "claims") switch
+ {
+ { Length: 0 } => null,
+ string enc => System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(enc))
+ };
+ }
+
+ return null;
+ }
+
private async ValueTask AuthorizeRequestOnChallengeAsyncInternal(HttpMessage message, bool async)
{
if (message.Request.Content == null && message.TryGetProperty(KeyVaultStashedContentKey, out var content))
@@ -91,8 +115,10 @@ private async ValueTask AuthorizeRequestOnChallengeAsyncInternal(HttpMessa
message.Request.Content = content as RequestContent;
}
+ string error = AuthorizationChallengeParser.GetChallengeParameterFromResponse(message.Response, "Bearer", "error");
string authority = GetRequestAuthority(message.Request);
string scope = AuthorizationChallengeParser.GetChallengeParameterFromResponse(message.Response, "Bearer", "resource");
+
if (scope != null)
{
scope += "/.default";
@@ -102,6 +128,15 @@ private async ValueTask AuthorizeRequestOnChallengeAsyncInternal(HttpMessa
scope = AuthorizationChallengeParser.GetChallengeParameterFromResponse(message.Response, "Bearer", "scope");
}
+ // Handle CAE Challenges
+ string claims = getDecodedClaimsParameter(error, message.Response);
+ if (claims != null)
+ {
+ // Get the scope from the cache
+ s_challengeCache.TryGetValue(authority, out _challenge);
+ scope = _challenge.Scopes[0];
+ }
+
if (scope is null)
{
if (s_challengeCache.TryGetValue(authority, out _challenge))
@@ -140,7 +175,7 @@ private async ValueTask AuthorizeRequestOnChallengeAsyncInternal(HttpMessa
s_challengeCache[authority] = _challenge;
}
- var context = new TokenRequestContext(_challenge.Scopes, parentRequestId: message.Request.ClientRequestId, tenantId: _challenge.TenantId);
+ var context = new TokenRequestContext(_challenge.Scopes, parentRequestId: message.Request.ClientRequestId, tenantId: _challenge.TenantId, isCaeEnabled: true, claims: claims);
if (async)
{
await AuthenticateAndAuthorizeRequestAsync(message, context).ConfigureAwait(false);
@@ -153,6 +188,81 @@ private async ValueTask AuthorizeRequestOnChallengeAsyncInternal(HttpMessa
return true;
}
+ ///
+ public override ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline)
+ {
+ return ProcessAsyncInternal(message, pipeline, true);
+ }
+
+ ///
+ public override void Process(HttpMessage message, ReadOnlyMemory pipeline)
+ {
+ ProcessAsyncInternal(message, pipeline, false).EnsureCompleted();
+ }
+
+ private async ValueTask ProcessAsyncInternal(HttpMessage message, ReadOnlyMemory pipeline, bool async)
+ {
+ if (message.Request.Uri.Scheme != Uri.UriSchemeHttps)
+ {
+ throw new InvalidOperationException("Bearer token authentication is not permitted for non TLS protected (https) endpoints.");
+ }
+
+ if (async)
+ {
+ await AuthorizeRequestAsync(message).ConfigureAwait(false);
+ await ProcessNextAsync(message, pipeline).ConfigureAwait(false);
+ }
+ else
+ {
+ AuthorizeRequest(message);
+ ProcessNext(message, pipeline);
+ }
+
+ // Check if we have received a challenge or we have not yet issued the first request.
+ if (message.Response.Status == (int)HttpStatusCode.Unauthorized && message.Response.Headers.Contains(HttpHeader.Names.WwwAuthenticate))
+ {
+ // Attempt to get the TokenRequestContext based on the challenge.
+ // If we fail to get the context, the challenge was not present or invalid.
+ // If we succeed in getting the context, authenticate the request and pass it up the policy chain.
+ if (async)
+ {
+ if (await AuthorizeRequestOnChallengeAsync(message).ConfigureAwait(false))
+ {
+ await ProcessNextAsync(message, pipeline).ConfigureAwait(false);
+ }
+ }
+ else
+ {
+ if (AuthorizeRequestOnChallenge(message))
+ {
+ ProcessNext(message, pipeline);
+ }
+ }
+
+ // Handle the scenario in which we get a CAE challenge back.
+ if (message.Response.Status == (int)HttpStatusCode.Unauthorized
+ && message.Response.Headers.Contains(HttpHeader.Names.WwwAuthenticate)
+ && AuthorizationChallengeParser.GetChallengeParameterFromResponse(message.Response, "Bearer", "claims") != null)
+ {
+ if (async)
+ {
+ if (await AuthorizeRequestOnChallengeAsync(message).ConfigureAwait(false))
+ {
+ await ProcessNextAsync(message, pipeline).ConfigureAwait(false);
+ }
+ }
+ else
+ {
+ if (AuthorizeRequestOnChallenge(message))
+ {
+ ProcessNext(message, pipeline);
+ }
+ }
+ }
+ // If we get a second CAE challenge, an unlikely scenario, we do not attempt to re-authenticate.
+ }
+ }
+
internal class ChallengeParameters
{
internal ChallengeParameters(Uri authorizationUri, string[] scopes)
diff --git a/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/Azure.Security.KeyVault.Shared.Tests.projitems b/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/Azure.Security.KeyVault.Shared.Tests.projitems
index 9f553875f304a..55413c17628b9 100644
--- a/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/Azure.Security.KeyVault.Shared.Tests.projitems
+++ b/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/Azure.Security.KeyVault.Shared.Tests.projitems
@@ -9,6 +9,7 @@
Azure.Security.KeyVault.Tests
+
diff --git a/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/ContinuousAccessEvaluationTestsBase.cs b/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/ContinuousAccessEvaluationTestsBase.cs
new file mode 100644
index 0000000000000..af2ee53ff9139
--- /dev/null
+++ b/sdk/keyvault/Azure.Security.KeyVault.Shared/tests/ContinuousAccessEvaluationTestsBase.cs
@@ -0,0 +1,171 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+using System.Threading.Tasks;
+using System.Threading;
+using Azure.Core.TestFramework;
+using Azure.Core;
+using System.IO;
+using System.Text.Json;
+using Azure.Core.Pipeline;
+
+namespace Azure.Security.KeyVault.Tests
+{
+ internal class ContinuousAccessEvaluationTestsBase
+ {
+ protected MockResponse defaultCaeChallenge = new MockResponse(401).WithHeader("WWW-Authenticate", @"Bearer realm="""", authorization_uri=""https://login.microsoftonline.com/common/oauth2/authorize"", error=""insufficient_claims"", claims=""eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==""");
+
+ protected MockResponse defaultInitialChallenge = new MockResponse(401).WithHeader("WWW-Authenticate", @"Bearer authorization=""https://login.windows.net/de763a21-49f7-4b08-a8e1-52c8fbc103b4"", resource=""https://vault.azure.net""");
+
+ private const string VaultHost = "test.vault.azure.net";
+ protected Uri VaultUri => new Uri("https://" + VaultHost);
+
+ protected MockTransport GetMockTransportWithCaeChallenges(int numberOfCaeChallenges = 1, MockResponse final200response = null )
+ {
+ if (numberOfCaeChallenges < 1)
+ {
+ throw new ArgumentOutOfRangeException(nameof(numberOfCaeChallenges), "Number of CAE challenges must be greater than or equal to 1.");
+ }
+
+ var responses = new List { defaultInitialChallenge };
+ for (int i = 0; i < numberOfCaeChallenges; i++)
+ {
+ responses.Add(defaultCaeChallenge);
+ }
+ if (final200response != null)
+ {
+ responses.Add(final200response);
+ }
+ return new MockTransport(responses.ToArray());
+ }
+
+ protected MockTransport GetMockCredentialTransport(int numberOfTokenResponses = 1)
+ {
+ if (numberOfTokenResponses < 1)
+ {
+ throw new ArgumentOutOfRangeException(nameof(numberOfTokenResponses), "Number of token responses must be greater than or equal to 1.");
+ }
+
+ var responses = new List();
+ for (int i = 0; i < numberOfTokenResponses; i++)
+ {
+ responses.Add(new MockResponse(200)
+ .WithJson("""
+ {
+ "token_type": "Bearer",
+ "expires_in": 3599,
+ "resource": "https://vault.azure.net",
+ "access_token": "foo"
+ }
+ """));
+ }
+ return new MockTransport(responses.ToArray());
+ }
+
+ protected class TokenCredentialStub : TokenCredential
+ {
+ public TokenCredentialStub() { }
+
+ public TokenCredentialStub(Func handler, bool isAsync)
+ {
+ setCallBack(handler, isAsync);
+ }
+
+ private Func> _getTokenAsyncHandler;
+ private Func _getTokenHandler;
+
+ public void setCallBack(Func handler, bool isAsync)
+ {
+ if (isAsync)
+ {
+#pragma warning disable 1998
+ _getTokenAsyncHandler = async (r, c) => handler(r, c);
+#pragma warning restore 1998
+ }
+ else
+ {
+ _getTokenHandler = handler;
+ }
+ }
+
+ public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)
+ => _getTokenAsyncHandler(requestContext, cancellationToken);
+
+ public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
+ => _getTokenHandler(requestContext, cancellationToken);
+ }
+
+ protected class MockCredential : TokenCredential
+ {
+ private readonly HttpPipeline _pipeline;
+ private readonly string _tenantId;
+ private readonly string _clientId;
+ private readonly string _clientSecret;
+ private const string TenantId = "72f988bf-86f1-41af-91ab-2d7cd011db47";
+
+ public MockCredential(MockTransport transport, string tenantId = TenantId, string clientId = "test_id", string clientSecret = "test_secret")
+ {
+ _pipeline = new HttpPipeline(transport);
+ _tenantId = tenantId ?? throw new ArgumentNullException(nameof(tenantId));
+ _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId));
+ _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret));
+ }
+
+ public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) => GetTokenAsync(requestContext, cancellationToken).EnsureCompleted();
+
+ public override async ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)
+ {
+ Request request = _pipeline.CreateRequest();
+ request.Method = RequestMethod.Post;
+ request.Headers.Add(HttpHeader.Common.FormUrlEncodedContentType);
+
+ request.Uri.Reset(new Uri($"https://login.windows.net/{_tenantId}/oauth2/v2.0/token"));
+
+ string body = $"response_type=token&grant_type=client_credentials&client_id={Uri.EscapeDataString(_clientId)}&client_secret={Uri.EscapeDataString(_clientSecret)}&scope={Uri.EscapeDataString(string.Join(" ", requestContext.Scopes))}";
+ ReadOnlyMemory content = Encoding.UTF8.GetBytes(body).AsMemory();
+ request.Content = RequestContent.Create(content);
+
+ Response response = await _pipeline.SendRequestAsync(request, cancellationToken);
+ if (response.Status == 200 || response.Status == 201)
+ {
+ return await DeserializeAsync(response.ContentStream, cancellationToken);
+ }
+
+ throw new RequestFailedException(response.Status, response.ReasonPhrase);
+ }
+
+ private static async Task DeserializeAsync(Stream content, CancellationToken cancellationToken)
+ {
+ using (JsonDocument json = await JsonDocument.ParseAsync(content, default, cancellationToken).ConfigureAwait(false))
+ {
+ return Deserialize(json.RootElement);
+ }
+ }
+
+ private static AccessToken Deserialize(JsonElement json)
+ {
+ string accessToken = null;
+ DateTimeOffset expiresOn = DateTimeOffset.MaxValue;
+
+ foreach (JsonProperty prop in json.EnumerateObject())
+ {
+ switch (prop.Name)
+ {
+ case "access_token":
+ accessToken = prop.Value.GetString();
+ break;
+
+ case "expires_in":
+ expiresOn = DateTimeOffset.UtcNow + TimeSpan.FromSeconds(prop.Value.GetInt64());
+ break;
+ }
+ }
+
+ return new AccessToken(accessToken, expiresOn);
+ }
+ }
+ }
+}