-
Notifications
You must be signed in to change notification settings - Fork 4.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding ChallengeCacheHandler to update challenge cache on 401 response (
#6950) * Adding ChallengeCacheHandler to update challenge cache on 401 responses with auth challenge * updates addressing PR feedback
- Loading branch information
Showing
4 changed files
with
336 additions
and
0 deletions.
There are no files selected for viewing
50 changes: 50 additions & 0 deletions
50
sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/ChallengeCacheHandler.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Net.Http; | ||
using System.Text; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
|
||
namespace Microsoft.Azure.KeyVault.Customized.Authentication | ||
{ | ||
/// <summary> | ||
/// A <see cref="DelegatingHandler"/> which will update the <see cref="HttpBearerChallengeCache"/> when a 401 response is returned with a WWW-Authenticate bearer challenge header. | ||
/// </summary> | ||
public class ChallengeCacheHandler : MessageProcessingHandler | ||
{ | ||
/// <summary> | ||
/// Returns the specified request without performing any processing | ||
/// </summary> | ||
/// <param name="request"></param> | ||
/// <param name="cancellationToken"></param> | ||
/// <returns></returns> | ||
protected override HttpRequestMessage ProcessRequest(HttpRequestMessage request, CancellationToken cancellationToken) | ||
{ | ||
return request; | ||
} | ||
/// <summary> | ||
/// Updates the <see cref="HttpBearerChallengeCache"/> when the specified response has a return code of 401 | ||
/// </summary> | ||
/// <param name="response">The response to evaluate</param> | ||
/// <param name="cancellationToken">The cancellation token</param> | ||
/// <returns></returns> | ||
protected override HttpResponseMessage ProcessResponse(HttpResponseMessage response, CancellationToken cancellationToken) | ||
{ | ||
// if the response came back as 401 and the response contains a bearer challenge update the challenge cache | ||
if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized) | ||
{ | ||
HttpBearerChallenge challenge = HttpBearerChallenge.GetBearerChallengeFromResponse(response); | ||
|
||
if (challenge != null) | ||
{ | ||
// Update challenge cache | ||
HttpBearerChallengeCache.GetInstance().SetChallengeForURL(response.RequestMessage.RequestUri, challenge); | ||
} | ||
} | ||
|
||
return response; | ||
} | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
254 changes: 254 additions & 0 deletions
254
sdk/keyvault/Microsoft.Azure.KeyVault/tests/ChallengeCacheHandlerTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,254 @@ | ||
using Microsoft.Azure.KeyVault.Customized.Authentication; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Net; | ||
using System.Net.Http; | ||
using System.Text; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using Xunit; | ||
using Kvp = System.Collections.Generic.KeyValuePair<string, string>; | ||
|
||
namespace Microsoft.Azure.KeyVault.Tests | ||
{ | ||
public class ChallengeCacheHandlerTests | ||
{ | ||
[Fact] | ||
public async Task CacheAddOn401Async() | ||
{ | ||
var handler = new ChallengeCacheHandler(); | ||
|
||
var expChallenge = MockChallenge.Create(); | ||
|
||
handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, expChallenge.ToString()); | ||
|
||
var client = new HttpClient(handler); | ||
|
||
var requestUrl = CreateMockUrl(2); | ||
|
||
var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl)); | ||
|
||
AssertChallengeCacheEntry(requestUrl, expChallenge); | ||
} | ||
|
||
[Fact] | ||
public async Task CacheUpdateOn401Async() | ||
{ | ||
string requestUrlBase = CreateMockUrl(); | ||
|
||
string requestUrl1 = CreateMockUrl(requestUrlBase, 2); | ||
|
||
HttpBearerChallengeCache.GetInstance().SetChallengeForURL(new Uri(requestUrl1), MockChallenge.Create().ToHttpBearerChallenge(requestUrl1)); | ||
|
||
string requestUrl2 = CreateMockUrl(requestUrlBase, 2); | ||
|
||
var handler = new ChallengeCacheHandler(); | ||
|
||
var expChallenge = MockChallenge.Create(); | ||
|
||
handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, expChallenge.ToString()); | ||
|
||
var client = new HttpClient(handler); | ||
|
||
var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl2)); | ||
|
||
AssertChallengeCacheEntry(requestUrl1, expChallenge); | ||
|
||
AssertChallengeCacheEntry(requestUrl2, expChallenge); | ||
} | ||
|
||
[Fact] | ||
public async Task CacheNotUpdatedNoChallengeAsync() | ||
{ | ||
var handler = new ChallengeCacheHandler(); | ||
|
||
handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, null); | ||
|
||
var client = new HttpClient(handler); | ||
|
||
var requestUrl = CreateMockUrl(2); | ||
|
||
var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl)); | ||
|
||
AssertChallengeCacheEntry(requestUrl, null); | ||
} | ||
|
||
[Fact] | ||
public async Task CacheNotUpdatedNon401Aysnc() | ||
{ | ||
var handler = new ChallengeCacheHandler(); | ||
|
||
var expChallenge = MockChallenge.Create(); | ||
|
||
handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Forbidden, expChallenge.ToString()); | ||
|
||
var client = new HttpClient(handler); | ||
|
||
var requestUrl = CreateMockUrl(2); | ||
|
||
var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl)); | ||
|
||
AssertChallengeCacheEntry(requestUrl, null); | ||
} | ||
|
||
[Fact] | ||
public async Task CacheNotUpdatedNonBearerChallengeAsync() | ||
{ | ||
var handler = new ChallengeCacheHandler(); | ||
|
||
handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, MockChallenge.Create("PoP").ToString()); | ||
|
||
var client = new HttpClient(handler); | ||
|
||
var requestUrl = CreateMockUrl(2); | ||
|
||
var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl)); | ||
|
||
AssertChallengeCacheEntry(requestUrl, null); | ||
} | ||
|
||
private static void AssertChallengeCacheEntry(string requestUrl, MockChallenge expChallenge) | ||
{ | ||
var actChallenge = HttpBearerChallengeCache.GetInstance().GetChallengeForURL(new Uri(requestUrl)); | ||
|
||
if (expChallenge == null) | ||
{ | ||
Assert.Null(actChallenge); | ||
} | ||
else | ||
{ | ||
Assert.NotNull(actChallenge); | ||
|
||
Assert.Equal(expChallenge.AuthorizationServer, actChallenge.AuthorizationServer); | ||
|
||
Assert.Equal(expChallenge.Scope, actChallenge.Scope); | ||
|
||
Assert.Equal(expChallenge.Resource, actChallenge.Resource); | ||
} | ||
} | ||
|
||
private static string BuildChallengeString(params Kvp[] parameters) | ||
{ | ||
// remove the trailing ',' and return | ||
return BuildChallengeString("Bearer", parameters); | ||
} | ||
|
||
private static string BuildChallengeString(string challengeType, params Kvp[] parameters) | ||
{ | ||
StringBuilder buff = new StringBuilder(challengeType).Append(" "); | ||
|
||
foreach (var kvp in parameters) | ||
{ | ||
buff.Append(kvp.Key).Append("=\"").Append(kvp.Value).Append("\","); | ||
} | ||
|
||
// remove the trailing ',' and return | ||
return buff.Remove(buff.Length - 1, 1).ToString(); | ||
} | ||
|
||
public static string CreateMockUrl(int pathCount = 0) | ||
{ | ||
return CreateMockUrl("https://" + Guid.NewGuid().ToString("N"), pathCount); | ||
} | ||
|
||
public static string CreateMockUrl(string baseUrl, int pathCount = 0) | ||
{ | ||
var buff = new StringBuilder(baseUrl); | ||
|
||
if(baseUrl.EndsWith("/")) | ||
{ | ||
buff.Remove(buff.Length - 1, 1); | ||
} | ||
|
||
for (int i = 0; i < pathCount; i++) | ||
{ | ||
buff.Append("/").Append(Guid.NewGuid().ToString("N")); | ||
} | ||
|
||
return buff.ToString(); | ||
} | ||
|
||
private class MockChallenge | ||
{ | ||
public string ChallengeType { get; set; } | ||
|
||
public string AuthorizationServer { get; set; } | ||
|
||
public string Resource { get; set; } | ||
|
||
public string Scope { get; set; } | ||
|
||
public static MockChallenge Create(string challengeType = null, string authority = null, string resource = null, string scope = null) | ||
{ | ||
var mock = new MockChallenge(); | ||
mock.ChallengeType = challengeType ?? "Bearer"; | ||
mock.AuthorizationServer = authority ?? CreateMockUrl(1); | ||
mock.Resource = resource ?? CreateMockUrl(0); | ||
mock.Scope = scope ?? mock.Resource + "/.default"; | ||
return mock; | ||
} | ||
|
||
public HttpBearerChallenge ToHttpBearerChallenge(string requestUrl) | ||
{ | ||
return new HttpBearerChallenge(new Uri(requestUrl), ToString()); | ||
} | ||
|
||
public override string ToString() | ||
{ | ||
var parameters = new List<Kvp>(); | ||
|
||
StringBuilder buff = new StringBuilder(); | ||
|
||
if(AuthorizationServer != null) | ||
{ | ||
parameters.Add(new Kvp("authorization", AuthorizationServer)); | ||
} | ||
|
||
if (Resource != null) | ||
{ | ||
parameters.Add(new Kvp("resource", Resource)); | ||
} | ||
|
||
if (Scope != null) | ||
{ | ||
parameters.Add(new Kvp("scope", Scope)); | ||
} | ||
|
||
return BuildChallengeString(ChallengeType, parameters.ToArray()); | ||
} | ||
|
||
} | ||
|
||
|
||
private class StaticChallengeResponseHandler : HttpMessageHandler | ||
{ | ||
private HttpStatusCode _statusCode; | ||
private string _challengeHeader; | ||
|
||
public StaticChallengeResponseHandler(HttpStatusCode statusCode, string challengeHeader) | ||
{ | ||
_statusCode = statusCode; | ||
_challengeHeader = challengeHeader; | ||
} | ||
|
||
protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) | ||
{ | ||
return Task.FromResult(CreateResponse(request)); | ||
} | ||
|
||
private HttpResponseMessage CreateResponse(HttpRequestMessage request) | ||
{ | ||
var response = new HttpResponseMessage(_statusCode); | ||
|
||
response.RequestMessage = request; | ||
|
||
if(_challengeHeader != null) | ||
{ | ||
response.Headers.Add("WWW-Authenticate", _challengeHeader); | ||
} | ||
|
||
return response; | ||
} | ||
} | ||
} | ||
} |