Skip to content

Commit

Permalink
Replace temporary string.Split calls with manual counting (#1964)
Browse files Browse the repository at this point in the history
* Replace temporary string.Split calls with manual counting

* Update src/Microsoft.IdentityModel.JsonWebTokens/JwtTokenUtilities.cs
  • Loading branch information
BrennanConroy authored Jul 25, 2023
1 parent 68898be commit 28f89d2
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ public virtual bool CanReadToken(string token)
// Set the maximum number of segments to MaxJwtSegmentCount + 1. This controls the number of splits and allows detecting the number of segments is too large.
// For example: "a.b.c.d.e.f.g.h" => [a], [b], [c], [d], [e], [f.g.h]. 6 segments.
// If just MaxJwtSegmentCount was used, then [a], [b], [c], [d], [e.f.g.h] would be returned. 5 segments.
string[] tokenParts = token.Split(new char[] { '.' }, JwtConstants.MaxJwtSegmentCount + 1);
if (tokenParts.Length == JwtConstants.JwsSegmentCount)
int tokenPartCount = JwtTokenUtilities.CountJwtTokenPart(token, JwtConstants.MaxJwtSegmentCount + 1);
if (tokenPartCount == JwtConstants.JwsSegmentCount)
return JwtTokenUtilities.RegexJws.IsMatch(token);
else if (tokenParts.Length == JwtConstants.JweSegmentCount)
else if (tokenPartCount == JwtConstants.JweSegmentCount)
return JwtTokenUtilities.RegexJwe.IsMatch(token);

LogHelper.LogInformation(LogMessages.IDX14107);
Expand Down
27 changes: 27 additions & 0 deletions src/Microsoft.IdentityModel.JsonWebTokens/JwtTokenUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,33 @@ internal static SecurityKey ResolveTokenSigningKey(string kid, string x5t, IEnum
return null;
}

/// <summary>
/// Counts the number of Jwt Token segments.
/// </summary>
/// <param name="token">The Jwt Token.</param>
/// <param name="maxCount">The maximum number of segments to count up to.</param>
/// <returns>The number of segments up to <paramref name="maxCount"/>.</returns>
internal static int CountJwtTokenPart(string token, int maxCount)
{
var count = 1;
var index = 0;
while (index < token.Length)
{
var dotIndex = token.IndexOf('.', index);
if (dotIndex < 0)
{
break;
}
count++;
index = dotIndex + 1;
if (count == maxCount)
{
break;
}
}
return count;
}

internal static IEnumerable<SecurityKey> ConcatSigningKeys(TokenValidationParameters tvp)
{
if (tvp == null)
Expand Down
27 changes: 16 additions & 11 deletions src/System.IdentityModel.Tokens.Jwt/JwtSecurityTokenHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,12 @@ public override bool CanReadToken(string token)
// Set the maximum number of segments to MaxJwtSegmentCount + 1. This controls the number of splits and allows detecting the number of segments is too large.
// For example: "a.b.c.d.e.f.g.h" => [a], [b], [c], [d], [e], [f.g.h]. 6 segments.
// If just MaxJwtSegmentCount was used, then [a], [b], [c], [d], [e.f.g.h] would be returned. 5 segments.
string[] tokenParts = token.Split(new char[] { '.' }, JwtConstants.MaxJwtSegmentCount + 1);
if (tokenParts.Length == JwtConstants.JwsSegmentCount)
int tokenPartCount = JwtTokenUtilities.CountJwtTokenPart(token, JwtConstants.MaxJwtSegmentCount + 1);
if (tokenPartCount == JwtConstants.JwsSegmentCount)
{
return JwtTokenUtilities.RegexJws.IsMatch(token);
}
else if (tokenParts.Length == JwtConstants.JweSegmentCount)
else if (tokenPartCount == JwtConstants.JweSegmentCount)
{
return JwtTokenUtilities.RegexJwe.IsMatch(token);
}
Expand Down Expand Up @@ -648,8 +648,12 @@ private JwtSecurityToken CreateJwtSecurityTokenPrivate(

string rawHeader = header.Base64UrlEncode();
string rawPayload = payload.Base64UrlEncode();
string message = string.Concat(header.Base64UrlEncode(), ".", payload.Base64UrlEncode());
string rawSignature = signingCredentials == null ? string.Empty : JwtTokenUtilities.CreateEncodedSignature(message, signingCredentials);
string rawSignature = string.Empty;
if (signingCredentials != null)
{
string message = string.Concat(rawHeader, ".", rawPayload);
rawSignature = JwtTokenUtilities.CreateEncodedSignature(message, signingCredentials);
}

LogHelper.LogInformation(LogMessages.IDX12722, rawHeader, rawPayload, rawSignature);

Expand Down Expand Up @@ -688,20 +692,21 @@ private JwtSecurityToken EncryptToken(
try
{
var header = new JwtHeader(encryptingCredentials, OutboundAlgorithmMap, tokenType, additionalHeaderClaims);
AuthenticatedEncryptionResult encryptionResult = encryptionProvider.Encrypt(Encoding.UTF8.GetBytes(innerJwt.RawData), Encoding.ASCII.GetBytes(header.Base64UrlEncode()));
var encodedHeader = header.Base64UrlEncode();
AuthenticatedEncryptionResult encryptionResult = encryptionProvider.Encrypt(Encoding.UTF8.GetBytes(innerJwt.RawData), Encoding.ASCII.GetBytes(encodedHeader));
return JwtConstants.DirectKeyUseAlg.Equals(encryptingCredentials.Alg) ?
new JwtSecurityToken(
header,
innerJwt,
header.Base64UrlEncode(),
encodedHeader,
string.Empty,
Base64UrlEncoder.Encode(encryptionResult.IV),
Base64UrlEncoder.Encode(encryptionResult.Ciphertext),
Base64UrlEncoder.Encode(encryptionResult.AuthenticationTag)) :
new JwtSecurityToken(
header,
innerJwt,
header.Base64UrlEncode(),
encodedHeader,
Base64UrlEncoder.Encode(wrappedKey),
Base64UrlEncoder.Encode(encryptionResult.IV),
Base64UrlEncoder.Encode(encryptionResult.Ciphertext),
Expand Down Expand Up @@ -842,12 +847,12 @@ public override ClaimsPrincipal ValidateToken(string token, TokenValidationParam
if (token.Length > MaximumTokenSizeInBytes)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(TokenLogMessages.IDX10209, LogHelper.MarkAsNonPII(token.Length), LogHelper.MarkAsNonPII(MaximumTokenSizeInBytes))));

var tokenParts = token.Split(new char[] { '.' }, JwtConstants.MaxJwtSegmentCount + 1);
int tokenPartCount = JwtTokenUtilities.CountJwtTokenPart(token, JwtConstants.MaxJwtSegmentCount + 1);

if (tokenParts.Length != JwtConstants.JwsSegmentCount && tokenParts.Length != JwtConstants.JweSegmentCount)
if (tokenPartCount != JwtConstants.JwsSegmentCount && tokenPartCount != JwtConstants.JweSegmentCount)
throw LogHelper.LogExceptionMessage(new SecurityTokenMalformedException(LogHelper.FormatInvariant(LogMessages.IDX12741, token)));

if (tokenParts.Length == JwtConstants.JweSegmentCount)
if (tokenPartCount == JwtConstants.JweSegmentCount)
{
var jwtToken = ReadJwtToken(token);
var decryptedJwt = DecryptToken(jwtToken, validationParameters);
Expand Down

0 comments on commit 28f89d2

Please sign in to comment.