Skip to content

Commit

Permalink
Implement lazy ClaimsIdentity creation from ValidatedToken on SAML an…
Browse files Browse the repository at this point in the history
…d SAML2 on the new validation model (#3051)

* Implemented the claimsidentity creation methods to allow ValidatedToken to lazily create the claims when accessed in SAML and SAML2 token handlers

* Added tests, updated returned ValidatedToken to generate the right ClaimsIdentity

* Addressed PR feedback
  • Loading branch information
iNinja authored Dec 9, 2024
1 parent fe4abcf commit 039e8f8
Show file tree
Hide file tree
Showing 10 changed files with 383 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
const Microsoft.IdentityModel.Tokens.Saml.LogMessages.IDX11402 = "IDX11402: Unable to read SamlSecurityToken. Exception thrown: '{0}'." -> string
const Microsoft.IdentityModel.Tokens.Saml2.LogMessages.IDX13003 = "IDX13003: Unable to read Saml2SecurityToken. Exception thrown: '{0}'." -> string
Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.StackFrames
Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.CreateClaimsIdentity(Microsoft.IdentityModel.Tokens.Saml.SamlSecurityToken samlToken, Microsoft.IdentityModel.Tokens.ValidationParameters validationParameters, string issuer) -> System.Security.Claims.ClaimsIdentity
Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.ValidatedConditions
Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.ValidatedConditions.ValidatedAudience.get -> string
Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.ValidatedConditions.ValidatedAudience.set -> void
Expand All @@ -12,12 +13,15 @@ Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.ValidateTokenAsync(
Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.ValidateTokenAsync(string token, Microsoft.IdentityModel.Tokens.ValidationParameters validationParameters, Microsoft.IdentityModel.Tokens.CallContext callContext, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task<Microsoft.IdentityModel.Tokens.ValidationResult<Microsoft.IdentityModel.Tokens.ValidatedToken>>
Microsoft.IdentityModel.Tokens.Saml.SamlValidationError
Microsoft.IdentityModel.Tokens.Saml.SamlValidationError.SamlValidationError(Microsoft.IdentityModel.Tokens.MessageDetail messageDetail, Microsoft.IdentityModel.Tokens.ValidationFailureType validationFailureType, System.Type exceptionType, System.Diagnostics.StackFrame stackFrame, System.Exception innerException = null) -> void
Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityTokenHandler.CreateClaimsIdentity(Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityToken samlToken, Microsoft.IdentityModel.Tokens.ValidationParameters validationParameters, string issuer) -> System.Security.Claims.ClaimsIdentity
Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityTokenHandler.StackFrames
Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityTokenHandler.ValidateTokenAsync(Microsoft.IdentityModel.Tokens.SecurityToken securityToken, Microsoft.IdentityModel.Tokens.ValidationParameters validationParameters, Microsoft.IdentityModel.Tokens.CallContext callContext, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task<Microsoft.IdentityModel.Tokens.ValidationResult<Microsoft.IdentityModel.Tokens.ValidatedToken>>
Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityTokenHandler.ValidateTokenAsync(string token, Microsoft.IdentityModel.Tokens.ValidationParameters validationParameters, Microsoft.IdentityModel.Tokens.CallContext callContext, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task<Microsoft.IdentityModel.Tokens.ValidationResult<Microsoft.IdentityModel.Tokens.ValidatedToken>>
Microsoft.IdentityModel.Tokens.Saml2.Saml2ValidationError
Microsoft.IdentityModel.Tokens.Saml2.Saml2ValidationError.Saml2ValidationError(Microsoft.IdentityModel.Tokens.MessageDetail messageDetail, Microsoft.IdentityModel.Tokens.ValidationFailureType validationFailureType, System.Type exceptionType, System.Diagnostics.StackFrame stackFrame, System.Exception innerException = null) -> void
override Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.CreateClaimsIdentityInternal(Microsoft.IdentityModel.Tokens.SecurityToken securityToken, Microsoft.IdentityModel.Tokens.ValidationParameters validationParameters, string issuer) -> System.Security.Claims.ClaimsIdentity
override Microsoft.IdentityModel.Tokens.Saml.SamlValidationError.GetException() -> System.Exception
override Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityTokenHandler.CreateClaimsIdentityInternal(Microsoft.IdentityModel.Tokens.SecurityToken securityToken, Microsoft.IdentityModel.Tokens.ValidationParameters validationParameters, string issuer) -> System.Security.Claims.ClaimsIdentity
override Microsoft.IdentityModel.Tokens.Saml2.Saml2ValidationError.GetException() -> System.Exception
static Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.StackFrames.IssuerValidationFailed -> System.Diagnostics.StackFrame
static Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.StackFrames.SignatureValidationFailed -> System.Diagnostics.StackFrame
Expand All @@ -44,6 +48,7 @@ static Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityTokenHandler.StackFrame
static Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityTokenHandler.StackFrames.TokenNull -> System.Diagnostics.StackFrame
static Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityTokenHandler.StackFrames.TokenValidationParametersNull -> System.Diagnostics.StackFrame
static Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityTokenHandler.ValidateSignature(Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityToken samlToken, Microsoft.IdentityModel.Tokens.ValidationParameters validationParameters, Microsoft.IdentityModel.Tokens.CallContext callContext) -> Microsoft.IdentityModel.Tokens.ValidationResult<Microsoft.IdentityModel.Tokens.SecurityKey>
virtual Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.ProcessStatements(Microsoft.IdentityModel.Tokens.Saml.SamlSecurityToken samlToken, string issuer, Microsoft.IdentityModel.Tokens.ValidationParameters validationParameters) -> System.Collections.Generic.IEnumerable<System.Security.Claims.ClaimsIdentity>
virtual Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.ReadSamlToken(string token, Microsoft.IdentityModel.Tokens.CallContext callContext) -> Microsoft.IdentityModel.Tokens.ValidationResult<Microsoft.IdentityModel.Tokens.Saml.SamlSecurityToken>
virtual Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.ValidateConditions(Microsoft.IdentityModel.Tokens.Saml.SamlSecurityToken samlToken, Microsoft.IdentityModel.Tokens.ValidationParameters validationParameters, Microsoft.IdentityModel.Tokens.CallContext callContext) -> Microsoft.IdentityModel.Tokens.ValidationResult<Microsoft.IdentityModel.Tokens.Saml.SamlSecurityTokenHandler.ValidatedConditions>
virtual Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityTokenHandler.ReadSaml2Token(string token, Microsoft.IdentityModel.Tokens.CallContext callContext) -> Microsoft.IdentityModel.Tokens.ValidationResult<Microsoft.IdentityModel.Tokens.Saml2.Saml2SecurityToken>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Collections.Generic;
using System.Linq;
using System.Security.Claims;
using Microsoft.IdentityModel.Logging;

#nullable enable
namespace Microsoft.IdentityModel.Tokens.Saml
{
/// <summary>
/// A <see cref="SecurityTokenHandler"/> designed for creating and validating Saml Tokens. See: http://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf
/// </summary>
public partial class SamlSecurityTokenHandler : SecurityTokenHandler
{
internal override ClaimsIdentity CreateClaimsIdentityInternal(SecurityToken securityToken, ValidationParameters validationParameters, string issuer)
{
return CreateClaimsIdentity((SamlSecurityToken)securityToken, validationParameters, issuer);
}

internal ClaimsIdentity CreateClaimsIdentity(SamlSecurityToken samlToken, ValidationParameters validationParameters, string issuer)
{
if (samlToken == null)
throw LogHelper.LogArgumentNullException(nameof(samlToken));

if (samlToken.Assertion == null)
throw LogHelper.LogArgumentNullException(LogMessages.IDX11110);

var actualIssuer = issuer;
if (string.IsNullOrWhiteSpace(issuer))
actualIssuer = ClaimsIdentity.DefaultIssuer;

IEnumerable<ClaimsIdentity> identities = ProcessStatements(
samlToken,
actualIssuer,
validationParameters);

return identities.First();
}

/// <summary>
/// Processes all statements to generate claims.
/// </summary>
/// <param name="samlToken">A <see cref="SamlSecurityToken"/> that will be used to create the claims.</param>
/// <param name="issuer">The issuer.</param>
/// <param name="validationParameters">The <see cref="TokenValidationParameters"/> to be used for validating the token.</param>
/// <returns>A <see cref="IEnumerable{ClaimsIdentity}"/> containing the claims from the <see cref="SamlSecurityToken"/>.</returns>
/// <exception cref="SamlSecurityTokenException">if the statement is not a <see cref="SamlSubjectStatement"/>.</exception>
internal virtual IEnumerable<ClaimsIdentity> ProcessStatements(SamlSecurityToken samlToken, string issuer, ValidationParameters validationParameters)
{
if (samlToken == null)
throw LogHelper.LogArgumentNullException(nameof(samlToken));

if (validationParameters == null)
throw LogHelper.LogArgumentNullException(nameof(validationParameters));

var identityDict = new Dictionary<SamlSubject, ClaimsIdentity>(SamlSubjectEqualityComparer);
foreach (SamlStatement? item in samlToken.Assertion.Statements)
{
if (item is not SamlSubjectStatement statement)
throw LogHelper.LogExceptionMessage(new SamlSecurityTokenException(LogMessages.IDX11515));

if (!identityDict.TryGetValue(statement.Subject, out ClaimsIdentity? identity))
{
identity = validationParameters.CreateClaimsIdentity(samlToken, issuer);
ProcessSubject(statement.Subject, identity, issuer);
identityDict.Add(statement.Subject, identity);
}

if (statement is SamlAttributeStatement attrStatement)
ProcessAttributeStatement(attrStatement, identity, issuer);
else if (statement is SamlAuthenticationStatement authnStatement)
ProcessAuthenticationStatement(authnStatement, identity, issuer);
else if (statement is SamlAuthorizationDecisionStatement authzStatement)
ProcessAuthorizationDecisionStatement(authzStatement, identity, issuer);
else
ProcessCustomSubjectStatement(statement, identity, issuer);
}

return identityDict.Values;
}
}
}
#nullable restore
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ internal async Task<ValidationResult<ValidatedToken>> ValidateTokenAsync(
if (!conditionsResult.IsValid)
return conditionsResult.UnwrapError().AddCurrentStackFrame();

ValidationResult<ValidatedIssuer> issuerValidationResult;

try
{
ValidationResult<ValidatedIssuer> issuerValidationResult = await validationParameters.IssuerValidatorAsync(
issuerValidationResult = await validationParameters.IssuerValidatorAsync(
samlToken.Issuer,
samlToken,
validationParameters,
Expand All @@ -101,10 +103,10 @@ internal async Task<ValidationResult<ValidatedToken>> ValidateTokenAsync(
ex);
}

ValidationResult<DateTime?>? tokenReplayValidationResult = null;

if (samlToken.Assertion.Conditions is not null)
{
ValidationResult<DateTime?> tokenReplayValidationResult;

try
{
tokenReplayValidationResult = validationParameters.TokenReplayValidator(
Expand All @@ -113,8 +115,8 @@ internal async Task<ValidationResult<ValidatedToken>> ValidateTokenAsync(
validationParameters,
callContext);

if (!tokenReplayValidationResult.IsValid)
return tokenReplayValidationResult.UnwrapError().AddCurrentStackFrame();
if (!tokenReplayValidationResult.Value.IsValid)
return tokenReplayValidationResult.Value.UnwrapError().AddCurrentStackFrame();
}
#pragma warning disable CA1031 // Do not catch general exception types
catch (Exception ex)
Expand Down Expand Up @@ -165,7 +167,15 @@ internal async Task<ValidationResult<ValidatedToken>> ValidateTokenAsync(
ex);
}

return new ValidatedToken(samlToken, this, validationParameters);
return new ValidatedToken(samlToken, this, validationParameters)
{
ValidatedAudience = conditionsResult.UnwrapResult().ValidatedAudience,
ValidatedLifetime = conditionsResult.UnwrapResult().ValidatedLifetime,
ValidatedIssuer = issuerValidationResult.UnwrapResult(),
ValidatedTokenReplayExpirationTime = tokenReplayValidationResult?.UnwrapResult(),
ValidatedSigningKey = signatureValidationResult.UnwrapResult(),
ValidatedSigningKeyLifetime = issuerSigningKeyValidationResult.UnwrapResult(),
};
}

// ValidatedConditions is basically a named tuple but using a record struct better expresses the intent.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Security.Claims;
using Microsoft.IdentityModel.Logging;

#nullable enable
namespace Microsoft.IdentityModel.Tokens.Saml2
{
/// <summary>
/// A <see cref="SecurityTokenHandler"/> designed for creating and validating Saml2 Tokens. See: http://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf
/// </summary>
public partial class Saml2SecurityTokenHandler : SecurityTokenHandler
{
internal override ClaimsIdentity CreateClaimsIdentityInternal(SecurityToken securityToken, ValidationParameters validationParameters, string issuer)
{
return CreateClaimsIdentity((Saml2SecurityToken)securityToken, validationParameters, issuer);
}

internal ClaimsIdentity CreateClaimsIdentity(Saml2SecurityToken samlToken, ValidationParameters validationParameters, string issuer)
{
if (samlToken == null)
throw LogHelper.LogArgumentNullException(nameof(samlToken));

if (samlToken.Assertion == null)
throw LogHelper.LogArgumentNullException(LogMessages.IDX13110);

if (validationParameters == null)
throw LogHelper.LogArgumentNullException(nameof(validationParameters));

string actualIssuer = issuer;
if (string.IsNullOrWhiteSpace(issuer))
actualIssuer = ClaimsIdentity.DefaultIssuer;

ClaimsIdentity identity = validationParameters.CreateClaimsIdentity(samlToken, actualIssuer);

ProcessSubject(samlToken.Assertion.Subject, identity, actualIssuer);
ProcessStatements(samlToken.Assertion.Statements, identity, actualIssuer);

return identity;
}
}
}
#nullable restore
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ internal async Task<ValidationResult<ValidatedToken>> ValidateTokenAsync(
if (!conditionsResult.IsValid)
return conditionsResult.UnwrapError().AddCurrentStackFrame();

ValidationResult<ValidatedIssuer> issuerValidationResult;

try
{
ValidationResult<ValidatedIssuer> issuerValidationResult = await validationParameters.IssuerValidatorAsync(
issuerValidationResult = await validationParameters.IssuerValidatorAsync(
samlToken.Issuer,
samlToken,
validationParameters,
Expand All @@ -105,10 +107,10 @@ internal async Task<ValidationResult<ValidatedToken>> ValidateTokenAsync(
ex);
}

ValidationResult<DateTime?>? tokenReplayValidationResult = null;

if (samlToken.Assertion.Conditions is not null)
{
ValidationResult<DateTime?> tokenReplayValidationResult;

try
{
tokenReplayValidationResult = validationParameters.TokenReplayValidator(
Expand All @@ -117,8 +119,8 @@ internal async Task<ValidationResult<ValidatedToken>> ValidateTokenAsync(
validationParameters,
callContext);

if (!tokenReplayValidationResult.IsValid)
return tokenReplayValidationResult.UnwrapError().AddCurrentStackFrame();
if (!tokenReplayValidationResult.Value.IsValid)
return tokenReplayValidationResult.Value.UnwrapError().AddCurrentStackFrame();
}
#pragma warning disable CA1031 // Do not catch general exception types
catch (Exception ex)
Expand Down Expand Up @@ -168,7 +170,15 @@ internal async Task<ValidationResult<ValidatedToken>> ValidateTokenAsync(
ex);
}

return new ValidatedToken(samlToken, this, validationParameters);
return new ValidatedToken(samlToken, this, validationParameters)
{
ValidatedAudience = conditionsResult.UnwrapResult().ValidatedAudience,
ValidatedLifetime = conditionsResult.UnwrapResult().ValidatedLifetime,
ValidatedIssuer = issuerValidationResult.UnwrapResult(),
ValidatedTokenReplayExpirationTime = tokenReplayValidationResult?.UnwrapResult(),
ValidatedSigningKey = signatureValidationResult.UnwrapResult(),
ValidatedSigningKeyLifetime = issuerSigningKeyValidationResult.UnwrapResult(),
};
}

// ValidatedConditions is basically a named tuple but using a record struct better expresses the intent.
Expand Down
Loading

0 comments on commit 039e8f8

Please sign in to comment.