Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Verify authentication tag length #2569

Merged
merged 11 commits into from
May 1, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.IO;
using System.Security.Cryptography;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using Microsoft.IdentityModel.Logging;

namespace Microsoft.IdentityModel.Tokens
Expand Down Expand Up @@ -32,6 +33,7 @@ private struct AuthenticatedKeys
private DecryptionDelegate DecryptFunction;
private EncryptionDelegate EncryptFunction;
private const string _className = "Microsoft.IdentityModel.Tokens.AuthenticatedEncryptionProvider";
internal const string _skipValidationOfAuthenticationTagLength = "Switch.Microsoft.IdentityModel.SkipAuthenticationTagLengthValidation";
kellyyangsong marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Initializes a new instance of the <see cref="AuthenticatedEncryptionProvider"/> class used for encryption and decryption.
Expand Down Expand Up @@ -165,6 +167,12 @@ private AuthenticatedEncryptionResult EncryptWithAesCbc(byte[] plaintext, byte[]
private byte[] DecryptWithAesCbc(byte[] ciphertext, byte[] authenticatedData, byte[] iv, byte[] authenticationTag)
{
// Verify authentication Tag
if (ShouldValidateAuthenticationTagLength()
&& SymmetricSignatureProvider.ExpectedSignatureSizeInBytes.TryGetValue(Algorithm, out int expectedTagLength)
kellyyangsong marked this conversation as resolved.
Show resolved Hide resolved
&& expectedTagLength != authenticationTag.Length)
throw LogHelper.LogExceptionMessage(new SecurityTokenDecryptionFailedException(
LogHelper.FormatInvariant(LogMessages.IDX10625, authenticationTag.Length, expectedTagLength, Base64UrlEncoder.Encode(authenticationTag), Algorithm)));

byte[] al = Utility.ConvertToBigEndian(authenticatedData.Length * 8);
byte[] macBytes = new byte[authenticatedData.Length + iv.Length + ciphertext.Length + al.Length];
Array.Copy(authenticatedData, 0, macBytes, 0, authenticatedData.Length);
Expand All @@ -189,6 +197,11 @@ private byte[] DecryptWithAesCbc(byte[] ciphertext, byte[] authenticatedData, by
}
}

private static bool ShouldValidateAuthenticationTagLength()
{
return !(AppContext.TryGetSwitch(_skipValidationOfAuthenticationTagLength, out bool skipValidation) && skipValidation);
}

private AuthenticatedKeys CreateAuthenticatedKeys()
{
ValidateKeySize(Key, Algorithm);
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.IdentityModel.Tokens/LogMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ internal static class LogMessages
// public const string IDX10622 = "IDX10622:";
// public const string IDX10623 = "IDX10623:";
// public const string IDX10624 = "IDX10624:";
public const string IDX10625 = "IDX10625: Failed to verify the authenticationTag length, the actual tag length '{0}' does not match the expected tag length '{1}'. authenticationTag: '{2}', algorithm: '{3}' See: https://aka.ms/IdentityModel/SkipAuthenticationTagLengthValidation";
// public const string IDX10627 = "IDX10627:";
public const string IDX10628 = "IDX10628: Cannot set the MinimumSymmetricKeySizeInBits to less than '{0}'.";
public const string IDX10630 = "IDX10630: The '{0}' for signing cannot be smaller than '{1}' bits. KeySize: '{2}'.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.IdentityModel.Tokens.Jwt.Tests;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Runtime.InteropServices;
using System.Security.Claims;
using System.Security.Cryptography;
Expand Down Expand Up @@ -86,9 +85,9 @@ public void Base64UrlEncodedUnsignedJwtHeader()
public void CreateTokenThrowsNullArgumentException()
{
var handler = new JsonWebTokenHandler();
Assert.Throws<ArgumentNullException>(() => handler.CreateToken(null, Default.SymmetricEncryptingCredentials, new Dictionary<string, object> { {"key", "value" } }));
Assert.Throws<ArgumentNullException>(() => handler.CreateToken("Payload", (EncryptingCredentials) null, new Dictionary<string, object> { { "key", "value" } }));
Assert.Throws<ArgumentNullException>(() => handler.CreateToken("Payload", Default.SymmetricEncryptingCredentials, (Dictionary<string, object>) null));
Assert.Throws<ArgumentNullException>(() => handler.CreateToken(null, Default.SymmetricEncryptingCredentials, new Dictionary<string, object> { { "key", "value" } }));
Assert.Throws<ArgumentNullException>(() => handler.CreateToken("Payload", (EncryptingCredentials)null, new Dictionary<string, object> { { "key", "value" } }));
Assert.Throws<ArgumentNullException>(() => handler.CreateToken("Payload", Default.SymmetricEncryptingCredentials, (Dictionary<string, object>)null));
}

[Theory, MemberData(nameof(TokenValidationClaimsTheoryData))]
Expand Down Expand Up @@ -276,7 +275,7 @@ public static TheoryData<JwtTheoryData> SegmentTheoryData()
theoryData);


JwtTestData.InvalidEncodedSegmentsData("", theoryData);
JwtTestData.InvalidEncodedSegmentsData("", theoryData);
JwtTestData.ValidEncodedSegmentsData(theoryData);

return theoryData;
Expand Down Expand Up @@ -476,7 +475,7 @@ public static TheoryData<CreateTokenTheoryData> CreateJWEWithAesGcmTheoryData

tokenHandler.InboundClaimTypeMap.Clear();
var encryptionCredentials = KeyingMaterial.DefaultSymmetricEncryptingCreds_AesGcm128;
encryptionCredentials.CryptoProviderFactory = new CryptoProviderFactoryMock();
encryptionCredentials.CryptoProviderFactory = new CryptoProviderFactoryForGcm();
return new TheoryData<CreateTokenTheoryData>
{
new CreateTokenTheoryData
Expand Down Expand Up @@ -2824,7 +2823,7 @@ public static TheoryData<CreateTokenTheoryData> RoundTripJWEKeyWrapTestCases

// Test checks to make sure that default times are correctly added to the token
// upon token creation.
[Fact (Skip = "Rewrite test to use claims, string will not succeed")]
[Fact(Skip = "Rewrite test to use claims, string will not succeed")]
public void SetDefaultTimesOnTokenCreation()
{
// when the payload is passed as a string to JsonWebTokenHandler.CreateToken, we no longer
Expand Down Expand Up @@ -2980,9 +2979,9 @@ public async Task ValidateJsonWebTokenClaimMapping()
}
};

if(jsonValidationResult.IsValid && jwtValidationResult.IsValid)
if (jsonValidationResult.IsValid && jwtValidationResult.IsValid)
{
if(!IdentityComparer.AreEqual(jsonValidationResult, jwtValidationResult, context))
if (!IdentityComparer.AreEqual(jsonValidationResult, jwtValidationResult, context))
{
context.AddDiff("jsonValidationResult.IsValid && jwtValidationResult.IsValid, Validation results are not equal");
}
Expand Down Expand Up @@ -3216,7 +3215,7 @@ public void ValidateJWS(JwtTheoryData theoryData)
try
{
var handler = new JsonWebTokenHandler();
var validationResult =handler.ValidateTokenAsync(theoryData.Token, theoryData.ValidationParameters).Result;
var validationResult = handler.ValidateTokenAsync(theoryData.Token, theoryData.ValidationParameters).Result;
if (validationResult.Exception != null)
{
if (validationResult.IsValid)
Expand Down Expand Up @@ -3568,7 +3567,7 @@ public void ValidateJWSWithLastKnownGood(JwtTheoryData theoryData)
var setupValidationResult = handler.ValidateTokenAsync(theoryData.Token, theoryData.ValidationParameters).Result;

theoryData.ValidationParameters.ValidateWithLKG = previousValidateWithLKG;

if (setupValidationResult.Exception != null)
{
if (setupValidationResult.IsValid)
Expand Down Expand Up @@ -4189,6 +4188,143 @@ public static TheoryData<CreateTokenTheoryData> IncludeSecurityTokenOnFailureTes
},
};
}

[Theory, MemberData(nameof(ValidateAuthenticationTagLengthTheoryData))]
public void ValidateTokenAsync_ModifiedAuthNTag(CreateTokenTheoryData theoryData)
{
// arrange
AppContext.SetSwitch(AuthenticatedEncryptionProvider._skipValidationOfAuthenticationTagLength, theoryData.EnableAppContextSwitch);
var payload = new JObject()
kellyyangsong marked this conversation as resolved.
Show resolved Hide resolved
{
{ JwtRegisteredClaimNames.Email, "[email protected]" },
{ JwtRegisteredClaimNames.GivenName, "Bob" },
{ JwtRegisteredClaimNames.Iss, "http://Default.Issuer.com"},
{ JwtRegisteredClaimNames.Aud, "http://Default.Audience.com" },
{ JwtRegisteredClaimNames.Iat, EpochTime.GetIntDate(DateTime.Now).ToString() },
{ JwtRegisteredClaimNames.Nbf, EpochTime.GetIntDate(DateTime.Now).ToString() },
{ JwtRegisteredClaimNames.Exp, EpochTime.GetIntDate(DateTime.Now.AddDays(1)).ToString() },
kellyyangsong marked this conversation as resolved.
Show resolved Hide resolved
}.ToString();

var jsonWebTokenHandler = new JsonWebTokenHandler();
var signingCredentials = Default.SymmetricSigningCredentials;

if (SupportedAlgorithms.IsAesGcm(theoryData.Algorithm))
{
theoryData.EncryptingCredentials.CryptoProviderFactory = new CryptoProviderFactoryForGcm();
}

var jwe = jsonWebTokenHandler.CreateToken(payload, signingCredentials, theoryData.EncryptingCredentials);
var jweWithExtraCharacters = jwe + "_cannoli_hunts_truffles_";

// act
// calling ValidateTokenAsync.Result to prevent tests from sharing app context switch property
// normally, we would want to await ValidateTokenAsync().ConfigureAwait(false)
var tokenValidationResult = jsonWebTokenHandler.ValidateTokenAsync(jweWithExtraCharacters, theoryData.ValidationParameters).Result;

// assert
Assert.Equal(theoryData.IsValid, tokenValidationResult.IsValid);
}

public static TheoryData<CreateTokenTheoryData> ValidateAuthenticationTagLengthTheoryData()
{
var signingCredentials512 = new SigningCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Sha512);
return new TheoryData<CreateTokenTheoryData>()
{
new("Aes256Gcm_IsNotValidByDefault")
{
Algorithm = SecurityAlgorithms.Aes256Gcm,
EncryptingCredentials = KeyingMaterial.DefaultSymmetricEncryptingCreds_AesGcm256,
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = KeyingMaterial.DefaultSymmetricSigningCreds_256_Sha2.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = false
},
new("A128CBC-HS256_IsNotValidByDefault")
kellyyangsong marked this conversation as resolved.
Show resolved Hide resolved
{
Algorithm = SecurityAlgorithms.Aes128CbcHmacSha256,
EncryptingCredentials = new EncryptingCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Aes128CbcHmacSha256),
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = KeyingMaterial.JsonWebKeyRsa256SigningCredentials.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = false
},
new("A192CBC-HS384_IsNotValidByDefault")
{
Algorithm = SecurityAlgorithms.Aes192CbcHmacSha384,
EncryptingCredentials = new EncryptingCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Aes192CbcHmacSha384),
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = KeyingMaterial.JsonWebKeyRsa256SigningCredentials.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = false
},
new("A256CBC-HS512_IsNotValidByDefault")
{
Algorithm = SecurityAlgorithms.Aes256CbcHmacSha512,
EncryptingCredentials = new EncryptingCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Aes256CbcHmacSha512),
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = signingCredentials512.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = false
},
new("A128CBC-HS256_SkipTagLengthValidationAppContextSwitchOn_IsValid")
{
EnableAppContextSwitch = true,
Algorithm = SecurityAlgorithms.Aes128CbcHmacSha256,
EncryptingCredentials = new EncryptingCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Aes128CbcHmacSha256),
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = KeyingMaterial.JsonWebKeyRsa256SigningCredentials.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = true
},
new("A192CBC-HS384_SkipTagLengthValidationAppContextSwitchOn_IsValid")
{
EnableAppContextSwitch = true,
Algorithm = SecurityAlgorithms.Aes192CbcHmacSha384,
EncryptingCredentials = new EncryptingCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Aes192CbcHmacSha384),
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = KeyingMaterial.JsonWebKeyRsa256SigningCredentials.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = true
},
new("A256CBC-HS512_SkipTagLengthValidationAppContextSwitchOn_IsValid")
{
EnableAppContextSwitch = true,
EncryptingCredentials = new EncryptingCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Aes256CbcHmacSha512),
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = signingCredentials512.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = true
}
};
}
}

public class CreateTokenTheoryData : TheoryDataBase
Expand Down Expand Up @@ -4234,24 +4370,26 @@ public CreateTokenTheoryData(string testId) : base(testId)
public IEnumerable<SecurityKey> ExpectedDecryptionKeys { get; set; }

public Dictionary<string, object> ExpectedClaims { get; set; }

public bool EnableAppContextSwitch { get; set; } = false;
}

// Overrides CryptoProviderFactory.CreateAuthenticatedEncryptionProvider to create AuthenticatedEncryptionProviderMock that provides AesGcm encryption.
public class CryptoProviderFactoryMock: CryptoProviderFactory
public class CryptoProviderFactoryForGcm : CryptoProviderFactory
{
public override AuthenticatedEncryptionProvider CreateAuthenticatedEncryptionProvider(SecurityKey key, string algorithm)
{
if (SupportedAlgorithms.IsSupportedEncryptionAlgorithm(algorithm, key) && SupportedAlgorithms.IsAesGcm(algorithm))
return new AuthenticatedEncryptionProviderMock(key, algorithm);
return new AuthenticatedEncryptionProviderForGcm(key, algorithm);

return null;
}
}

// Overrides AuthenticatedEncryptionProvider.Encrypt to offer AesGcm encryption for testing.
public class AuthenticatedEncryptionProviderMock: AuthenticatedEncryptionProvider
public class AuthenticatedEncryptionProviderForGcm : AuthenticatedEncryptionProvider
{
public AuthenticatedEncryptionProviderMock(SecurityKey key, string algorithm): base(key, algorithm)
public AuthenticatedEncryptionProviderForGcm(SecurityKey key, string algorithm) : base(key, algorithm)
{ }

public override AuthenticatedEncryptionResult Encrypt(byte[] plaintext, byte[] authenticatedData)
Expand All @@ -4275,7 +4413,7 @@ public override AuthenticatedEncryptionResult Encrypt(byte[] plaintext, byte[] a
aes.Encrypt(iv, plaintext, ciphertext, authenticationTag, authenticatedData);
}

return new AuthenticatedEncryptionResult(Key, ciphertext, iv, authenticationTag);
return new AuthenticatedEncryptionResult(Key, ciphertext, iv, authenticationTag);
}
}

Expand Down
Loading