Skip to content

Commit

Permalink
Verify authentication tag length (#2569)
Browse files Browse the repository at this point in the history
* initial changes to verify auth tag length + app context switch + unit test

* use original tag length

* test passes

* add app context switch test

* update tag lengths to size in bytes, implement tests

* replace previous auth tag length check

* address comments

* add gcm test case

* update test case name

* update log message w/ aka.ms link, add in line that was removed

* move curly brace
  • Loading branch information
kellyyangsong authored May 1, 2024
1 parent d353b5a commit d51c2ad
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.IO;
using System.Security.Cryptography;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using Microsoft.IdentityModel.Logging;

namespace Microsoft.IdentityModel.Tokens
Expand Down Expand Up @@ -32,6 +33,7 @@ private struct AuthenticatedKeys
private DecryptionDelegate DecryptFunction;
private EncryptionDelegate EncryptFunction;
private const string _className = "Microsoft.IdentityModel.Tokens.AuthenticatedEncryptionProvider";
internal const string _skipValidationOfAuthenticationTagLength = "Switch.Microsoft.IdentityModel.SkipAuthenticationTagLengthValidation";

/// <summary>
/// Initializes a new instance of the <see cref="AuthenticatedEncryptionProvider"/> class used for encryption and decryption.
Expand Down Expand Up @@ -165,6 +167,12 @@ private AuthenticatedEncryptionResult EncryptWithAesCbc(byte[] plaintext, byte[]
private byte[] DecryptWithAesCbc(byte[] ciphertext, byte[] authenticatedData, byte[] iv, byte[] authenticationTag)
{
// Verify authentication Tag
if (ShouldValidateAuthenticationTagLength()
&& SymmetricSignatureProvider.ExpectedSignatureSizeInBytes.TryGetValue(Algorithm, out int expectedTagLength)
&& expectedTagLength != authenticationTag.Length)
throw LogHelper.LogExceptionMessage(new SecurityTokenDecryptionFailedException(
LogHelper.FormatInvariant(LogMessages.IDX10625, authenticationTag.Length, expectedTagLength, Base64UrlEncoder.Encode(authenticationTag), Algorithm)));

byte[] al = Utility.ConvertToBigEndian(authenticatedData.Length * 8);
byte[] macBytes = new byte[authenticatedData.Length + iv.Length + ciphertext.Length + al.Length];
Array.Copy(authenticatedData, 0, macBytes, 0, authenticatedData.Length);
Expand All @@ -189,6 +197,11 @@ private byte[] DecryptWithAesCbc(byte[] ciphertext, byte[] authenticatedData, by
}
}

private static bool ShouldValidateAuthenticationTagLength()
{
return !(AppContext.TryGetSwitch(_skipValidationOfAuthenticationTagLength, out bool skipValidation) && skipValidation);
}

private AuthenticatedKeys CreateAuthenticatedKeys()
{
ValidateKeySize(Key, Algorithm);
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.IdentityModel.Tokens/LogMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ internal static class LogMessages
// public const string IDX10622 = "IDX10622:";
// public const string IDX10623 = "IDX10623:";
// public const string IDX10624 = "IDX10624:";
public const string IDX10625 = "IDX10625: Failed to verify the authenticationTag length, the actual tag length '{0}' does not match the expected tag length '{1}'. authenticationTag: '{2}', algorithm: '{3}' See: https://aka.ms/IdentityModel/SkipAuthenticationTagLengthValidation";
// public const string IDX10627 = "IDX10627:";
public const string IDX10628 = "IDX10628: Cannot set the MinimumSymmetricKeySizeInBits to less than '{0}'.";
public const string IDX10630 = "IDX10630: The '{0}' for signing cannot be smaller than '{1}' bits. KeySize: '{2}'.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.IdentityModel.Tokens.Jwt.Tests;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Runtime.InteropServices;
using System.Security.Claims;
using System.Security.Cryptography;
Expand Down Expand Up @@ -86,9 +85,9 @@ public void Base64UrlEncodedUnsignedJwtHeader()
public void CreateTokenThrowsNullArgumentException()
{
var handler = new JsonWebTokenHandler();
Assert.Throws<ArgumentNullException>(() => handler.CreateToken(null, Default.SymmetricEncryptingCredentials, new Dictionary<string, object> { {"key", "value" } }));
Assert.Throws<ArgumentNullException>(() => handler.CreateToken("Payload", (EncryptingCredentials) null, new Dictionary<string, object> { { "key", "value" } }));
Assert.Throws<ArgumentNullException>(() => handler.CreateToken("Payload", Default.SymmetricEncryptingCredentials, (Dictionary<string, object>) null));
Assert.Throws<ArgumentNullException>(() => handler.CreateToken(null, Default.SymmetricEncryptingCredentials, new Dictionary<string, object> { { "key", "value" } }));
Assert.Throws<ArgumentNullException>(() => handler.CreateToken("Payload", (EncryptingCredentials)null, new Dictionary<string, object> { { "key", "value" } }));
Assert.Throws<ArgumentNullException>(() => handler.CreateToken("Payload", Default.SymmetricEncryptingCredentials, (Dictionary<string, object>)null));
}

[Theory, MemberData(nameof(TokenValidationClaimsTheoryData))]
Expand Down Expand Up @@ -276,7 +275,7 @@ public static TheoryData<JwtTheoryData> SegmentTheoryData()
theoryData);


JwtTestData.InvalidEncodedSegmentsData("", theoryData);
JwtTestData.InvalidEncodedSegmentsData("", theoryData);
JwtTestData.ValidEncodedSegmentsData(theoryData);

return theoryData;
Expand Down Expand Up @@ -476,7 +475,7 @@ public static TheoryData<CreateTokenTheoryData> CreateJWEWithAesGcmTheoryData

tokenHandler.InboundClaimTypeMap.Clear();
var encryptionCredentials = KeyingMaterial.DefaultSymmetricEncryptingCreds_AesGcm128;
encryptionCredentials.CryptoProviderFactory = new CryptoProviderFactoryMock();
encryptionCredentials.CryptoProviderFactory = new CryptoProviderFactoryForGcm();
return new TheoryData<CreateTokenTheoryData>
{
new CreateTokenTheoryData
Expand Down Expand Up @@ -2824,7 +2823,7 @@ public static TheoryData<CreateTokenTheoryData> RoundTripJWEKeyWrapTestCases

// Test checks to make sure that default times are correctly added to the token
// upon token creation.
[Fact (Skip = "Rewrite test to use claims, string will not succeed")]
[Fact(Skip = "Rewrite test to use claims, string will not succeed")]
public void SetDefaultTimesOnTokenCreation()
{
// when the payload is passed as a string to JsonWebTokenHandler.CreateToken, we no longer
Expand Down Expand Up @@ -2980,9 +2979,9 @@ public async Task ValidateJsonWebTokenClaimMapping()
}
};

if(jsonValidationResult.IsValid && jwtValidationResult.IsValid)
if (jsonValidationResult.IsValid && jwtValidationResult.IsValid)
{
if(!IdentityComparer.AreEqual(jsonValidationResult, jwtValidationResult, context))
if (!IdentityComparer.AreEqual(jsonValidationResult, jwtValidationResult, context))
{
context.AddDiff("jsonValidationResult.IsValid && jwtValidationResult.IsValid, Validation results are not equal");
}
Expand Down Expand Up @@ -3216,7 +3215,7 @@ public void ValidateJWS(JwtTheoryData theoryData)
try
{
var handler = new JsonWebTokenHandler();
var validationResult =handler.ValidateTokenAsync(theoryData.Token, theoryData.ValidationParameters).Result;
var validationResult = handler.ValidateTokenAsync(theoryData.Token, theoryData.ValidationParameters).Result;
if (validationResult.Exception != null)
{
if (validationResult.IsValid)
Expand Down Expand Up @@ -3568,7 +3567,7 @@ public void ValidateJWSWithLastKnownGood(JwtTheoryData theoryData)
var setupValidationResult = handler.ValidateTokenAsync(theoryData.Token, theoryData.ValidationParameters).Result;

theoryData.ValidationParameters.ValidateWithLKG = previousValidateWithLKG;

if (setupValidationResult.Exception != null)
{
if (setupValidationResult.IsValid)
Expand Down Expand Up @@ -4189,6 +4188,143 @@ public static TheoryData<CreateTokenTheoryData> IncludeSecurityTokenOnFailureTes
},
};
}

[Theory, MemberData(nameof(ValidateAuthenticationTagLengthTheoryData))]
public void ValidateTokenAsync_ModifiedAuthNTag(CreateTokenTheoryData theoryData)
{
// arrange
AppContext.SetSwitch(AuthenticatedEncryptionProvider._skipValidationOfAuthenticationTagLength, theoryData.EnableAppContextSwitch);
var payload = new JObject()
{
{ JwtRegisteredClaimNames.Email, "[email protected]" },
{ JwtRegisteredClaimNames.GivenName, "Bob" },
{ JwtRegisteredClaimNames.Iss, "http://Default.Issuer.com"},
{ JwtRegisteredClaimNames.Aud, "http://Default.Audience.com" },
{ JwtRegisteredClaimNames.Iat, EpochTime.GetIntDate(DateTime.Now).ToString() },
{ JwtRegisteredClaimNames.Nbf, EpochTime.GetIntDate(DateTime.Now).ToString() },
{ JwtRegisteredClaimNames.Exp, EpochTime.GetIntDate(DateTime.Now.AddDays(1)).ToString() },
}.ToString();

var jsonWebTokenHandler = new JsonWebTokenHandler();
var signingCredentials = Default.SymmetricSigningCredentials;

if (SupportedAlgorithms.IsAesGcm(theoryData.Algorithm))
{
theoryData.EncryptingCredentials.CryptoProviderFactory = new CryptoProviderFactoryForGcm();
}

var jwe = jsonWebTokenHandler.CreateToken(payload, signingCredentials, theoryData.EncryptingCredentials);
var jweWithExtraCharacters = jwe + "_cannoli_hunts_truffles_";

// act
// calling ValidateTokenAsync.Result to prevent tests from sharing app context switch property
// normally, we would want to await ValidateTokenAsync().ConfigureAwait(false)
var tokenValidationResult = jsonWebTokenHandler.ValidateTokenAsync(jweWithExtraCharacters, theoryData.ValidationParameters).Result;

// assert
Assert.Equal(theoryData.IsValid, tokenValidationResult.IsValid);
}

public static TheoryData<CreateTokenTheoryData> ValidateAuthenticationTagLengthTheoryData()
{
var signingCredentials512 = new SigningCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Sha512);
return new TheoryData<CreateTokenTheoryData>()
{
new("Aes256Gcm_IsNotValidByDefault")
{
Algorithm = SecurityAlgorithms.Aes256Gcm,
EncryptingCredentials = KeyingMaterial.DefaultSymmetricEncryptingCreds_AesGcm256,
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = KeyingMaterial.DefaultSymmetricSigningCreds_256_Sha2.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = false
},
new("A128CBC-HS256_IsNotValidByDefault")
{
Algorithm = SecurityAlgorithms.Aes128CbcHmacSha256,
EncryptingCredentials = new EncryptingCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Aes128CbcHmacSha256),
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = KeyingMaterial.JsonWebKeyRsa256SigningCredentials.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = false
},
new("A192CBC-HS384_IsNotValidByDefault")
{
Algorithm = SecurityAlgorithms.Aes192CbcHmacSha384,
EncryptingCredentials = new EncryptingCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Aes192CbcHmacSha384),
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = KeyingMaterial.JsonWebKeyRsa256SigningCredentials.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = false
},
new("A256CBC-HS512_IsNotValidByDefault")
{
Algorithm = SecurityAlgorithms.Aes256CbcHmacSha512,
EncryptingCredentials = new EncryptingCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Aes256CbcHmacSha512),
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = signingCredentials512.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = false
},
new("A128CBC-HS256_SkipTagLengthValidationAppContextSwitchOn_IsValid")
{
EnableAppContextSwitch = true,
Algorithm = SecurityAlgorithms.Aes128CbcHmacSha256,
EncryptingCredentials = new EncryptingCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Aes128CbcHmacSha256),
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = KeyingMaterial.JsonWebKeyRsa256SigningCredentials.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = true
},
new("A192CBC-HS384_SkipTagLengthValidationAppContextSwitchOn_IsValid")
{
EnableAppContextSwitch = true,
Algorithm = SecurityAlgorithms.Aes192CbcHmacSha384,
EncryptingCredentials = new EncryptingCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Aes192CbcHmacSha384),
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = KeyingMaterial.JsonWebKeyRsa256SigningCredentials.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = true
},
new("A256CBC-HS512_SkipTagLengthValidationAppContextSwitchOn_IsValid")
{
EnableAppContextSwitch = true,
EncryptingCredentials = new EncryptingCredentials(KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaPKCS1, SecurityAlgorithms.Aes256CbcHmacSha512),
ValidationParameters = new TokenValidationParameters
{
TokenDecryptionKey = signingCredentials512.Key,
IssuerSigningKey = Default.SymmetricSigningKey256,
ValidAudience = "http://Default.Audience.com",
ValidIssuer = "http://Default.Issuer.com",
},
IsValid = true
}
};
}
}

public class CreateTokenTheoryData : TheoryDataBase
Expand Down Expand Up @@ -4234,24 +4370,26 @@ public CreateTokenTheoryData(string testId) : base(testId)
public IEnumerable<SecurityKey> ExpectedDecryptionKeys { get; set; }

public Dictionary<string, object> ExpectedClaims { get; set; }

public bool EnableAppContextSwitch { get; set; } = false;
}

// Overrides CryptoProviderFactory.CreateAuthenticatedEncryptionProvider to create AuthenticatedEncryptionProviderMock that provides AesGcm encryption.
public class CryptoProviderFactoryMock: CryptoProviderFactory
public class CryptoProviderFactoryForGcm : CryptoProviderFactory
{
public override AuthenticatedEncryptionProvider CreateAuthenticatedEncryptionProvider(SecurityKey key, string algorithm)
{
if (SupportedAlgorithms.IsSupportedEncryptionAlgorithm(algorithm, key) && SupportedAlgorithms.IsAesGcm(algorithm))
return new AuthenticatedEncryptionProviderMock(key, algorithm);
return new AuthenticatedEncryptionProviderForGcm(key, algorithm);

return null;
}
}

// Overrides AuthenticatedEncryptionProvider.Encrypt to offer AesGcm encryption for testing.
public class AuthenticatedEncryptionProviderMock: AuthenticatedEncryptionProvider
public class AuthenticatedEncryptionProviderForGcm : AuthenticatedEncryptionProvider
{
public AuthenticatedEncryptionProviderMock(SecurityKey key, string algorithm): base(key, algorithm)
public AuthenticatedEncryptionProviderForGcm(SecurityKey key, string algorithm) : base(key, algorithm)
{ }

public override AuthenticatedEncryptionResult Encrypt(byte[] plaintext, byte[] authenticatedData)
Expand All @@ -4275,7 +4413,7 @@ public override AuthenticatedEncryptionResult Encrypt(byte[] plaintext, byte[] a
aes.Encrypt(iv, plaintext, ciphertext, authenticationTag, authenticatedData);
}

return new AuthenticatedEncryptionResult(Key, ciphertext, iv, authenticationTag);
return new AuthenticatedEncryptionResult(Key, ciphertext, iv, authenticationTag);
}
}

Expand Down

0 comments on commit d51c2ad

Please sign in to comment.