Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce allocation with Base64UrlEncoder #2162

Merged
merged 1 commit into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ namespace Microsoft.IdentityModel.JsonWebTokens
/// </summary>
public class JsonWebToken : SecurityToken
{
private char[] _hChars;
private ClaimsIdentity _claimsIdentity;
private bool _wasClaimsIdentitySet;

Expand Down Expand Up @@ -470,25 +469,39 @@ private void ReadToken(string encodedJson)
throw LogHelper.LogExceptionMessage(new SecurityTokenMalformedException(LogHelper.FormatInvariant(LogMessages.IDX14310, encodedJson)));

// right number of dots for JWE
_hChars = encodedJson.ToCharArray(0, Dot1);
ReadOnlyMemory<char> hChars = encodedJson.AsMemory(0, Dot1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) what does h mean in hChars? "header"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea, I just kept the original naming :)

stephentoub marked this conversation as resolved.
Show resolved Hide resolved

// header cannot be empty
if (_hChars.Length == 0)
if (hChars.IsEmpty)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14307, encodedJson)));

HeaderAsciiBytes = Encoding.ASCII.GetBytes(_hChars);
byte[] headerAsciiBytes = new byte[hChars.Length];
#if NET6_0_OR_GREATER
Encoding.ASCII.GetBytes(hChars.Span, headerAsciiBytes);
#else
unsafe
{
fixed (char* hCharsPtr = hChars.Span)
fixed (byte* headerAsciiBytesPtr = headerAsciiBytes)
{
Encoding.ASCII.GetBytes(hCharsPtr, hChars.Length, headerAsciiBytesPtr, headerAsciiBytes.Length);
}
}
#endif
HeaderAsciiBytes = headerAsciiBytes;

try
{
Header = new JsonClaimSet(Base64UrlEncoder.UnsafeDecode(_hChars));
Header = new JsonClaimSet(Base64UrlEncoder.UnsafeDecode(hChars));
}
catch (Exception ex)
{
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14102, encodedJson.Substring(0, Dot1), encodedJson), ex));
}

// dir does not have any key bytes
char[] encryptedKeyBytes = encodedJson.ToCharArray(Dot1 + 1, Dot2 - Dot1 - 1);
if (encryptedKeyBytes.Length != 0)
ReadOnlyMemory<char> encryptedKeyBytes = encodedJson.AsMemory(Dot1 + 1, Dot2 - Dot1 - 1);
if (!encryptedKeyBytes.IsEmpty)
{
EncryptedKeyBytes = Base64UrlEncoder.UnsafeDecode(encryptedKeyBytes);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a comment for existing code, but all the other calls to Base64UrlEncoder.UnsafeDecode have a try-catch around them, but this one doesn't. Do we know why?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_encryptedKey = encodedJson.Substring(Dot1 + 1, Dot2 - Dot1 - 1);
Expand All @@ -498,8 +511,8 @@ private void ReadToken(string encodedJson)
_encryptedKey = string.Empty;
}

char[] initializationVectorChars = encodedJson.ToCharArray(Dot2 + 1, Dot3 - Dot2 - 1);
if (initializationVectorChars.Length == 0)
ReadOnlyMemory<char> initializationVectorChars = encodedJson.AsMemory(Dot2 + 1, Dot3 - Dot2 - 1);
if (initializationVectorChars.IsEmpty)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14308, encodedJson)));

try
Expand All @@ -511,8 +524,8 @@ private void ReadToken(string encodedJson)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14309, encodedJson, encodedJson), ex));
}

char[] authTagChars = encodedJson.ToCharArray(Dot4 + 1, encodedJson.Length - Dot4 - 1);
if (authTagChars.Length == 0)
ReadOnlyMemory<char> authTagChars = encodedJson.AsMemory(Dot4 + 1);
if (authTagChars.IsEmpty)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14310, encodedJson)));

try
Expand All @@ -524,13 +537,13 @@ private void ReadToken(string encodedJson)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14311, encodedJson, encodedJson), ex));
}

char[] cipherTextBytes = encodedJson.ToCharArray(Dot3 + 1, Dot4 - Dot3 - 1);
if (cipherTextBytes.Length == 0)
ReadOnlyMemory<char> cipherTextBytes = encodedJson.AsMemory(Dot3 + 1, Dot4 - Dot3 - 1);
if (cipherTextBytes.IsEmpty)
throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14306, encodedJson)));

try
{
CipherTextBytes = Base64UrlEncoder.UnsafeDecode(encodedJson.ToCharArray(Dot3 + 1, Dot4 - Dot3 - 1));
CipherTextBytes = Base64UrlEncoder.UnsafeDecode(cipherTextBytes);
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<PackageId>Microsoft.IdentityModel.JsonWebTokens</PackageId>
<PackageTags>.NET;Windows;Authentication;Identity;Json Web Token</PackageTags>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)'=='Debug'">
Expand Down
173 changes: 85 additions & 88 deletions src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
// Licensed under the MIT License.

using System;
using System.Buffers;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.IdentityModel.Logging;

Expand All @@ -18,18 +22,6 @@ public static class Base64UrlEncoder
private const char base64UrlCharacter62 = '-';
private const char base64UrlCharacter63 = '_';

/// <summary>
/// Encoding table
/// </summary>
internal static readonly char[] s_base64Table =
{
'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z',
'a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z',
'0','1','2','3','4','5','6','7','8','9',
base64UrlCharacter62,
base64UrlCharacter63
};

/// <summary>
/// The following functions perform base64url encoding which differs from regular base64 encoding as follows
/// * padding is skipped so the pad character '=' doesn't have to be percent encoded
Expand Down Expand Up @@ -90,7 +82,8 @@ public static string Encode(byte[] inArray, int offset, int length)
int lengthmod3 = length % 3;
int limit = offset + (length - lengthmod3);
char[] output = new char[(length + 2) / 3 * 4];
char[] table = s_base64Table;
ReadOnlySpan<byte> table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"u8;

int i, j = 0;

// takes 3 bytes from inArray and insert 4 bytes into output
Expand All @@ -100,10 +93,10 @@ public static string Encode(byte[] inArray, int offset, int length)
byte d1 = inArray[i + 1];
byte d2 = inArray[i + 2];

output[j + 0] = table[d0 >> 2];
output[j + 1] = table[((d0 & 0x03) << 4) | (d1 >> 4)];
output[j + 2] = table[((d1 & 0x0f) << 2) | (d2 >> 6)];
output[j + 3] = table[d2 & 0x3f];
output[j + 0] = (char)table[d0 >> 2];
output[j + 1] = (char)table[((d0 & 0x03) << 4) | (d1 >> 4)];
output[j + 2] = (char)table[((d1 & 0x0f) << 2) | (d2 >> 6)];
output[j + 3] = (char)table[d2 & 0x3f];
j += 4;
}

Expand All @@ -117,9 +110,9 @@ public static string Encode(byte[] inArray, int offset, int length)
byte d0 = inArray[i];
byte d1 = inArray[i + 1];

output[j + 0] = table[d0 >> 2];
output[j + 1] = table[((d0 & 0x03) << 4) | (d1 >> 4)];
output[j + 2] = table[(d1 & 0x0f) << 2];
output[j + 0] = (char)table[d0 >> 2];
output[j + 1] = (char)table[((d0 & 0x03) << 4) | (d1 >> 4)];
output[j + 2] = (char)table[(d1 & 0x0f) << 2];
j += 3;
}
break;
Expand All @@ -128,8 +121,8 @@ public static string Encode(byte[] inArray, int offset, int length)
{
byte d0 = inArray[i];

output[j + 0] = table[d0 >> 2];
output[j + 1] = table[(d0 & 0x03) << 4];
output[j + 0] = (char)table[d0 >> 2];
output[j + 1] = (char)table[(d0 & 0x03) << 4];
j += 2;
}
break;
Expand Down Expand Up @@ -168,106 +161,100 @@ internal static string EncodeString(string str)
public static byte[] DecodeBytes(string str)
{
_ = str ?? throw LogHelper.LogExceptionMessage(new ArgumentNullException(nameof(str)));
return UnsafeDecode(str);
return UnsafeDecode(str.AsMemory());
}

internal static unsafe byte[] UnsafeDecode(string str)
#if NET6_0_OR_GREATER
[SkipLocalsInit]
#endif
internal static unsafe byte[] UnsafeDecode(ReadOnlyMemory<char> str)
{
int mod = str.Length % 4;
if (mod == 1)
throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, str)));
throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, str.ToString())));

bool needReplace = false;
bool needReplace = str.Span.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63) >= 0;
int decodedLength = str.Length + (4 - mod) % 4;

for (int i = 0; i < str.Length; i++)
{
if (str[i] == base64UrlCharacter62 || str[i] == base64UrlCharacter63)
{
needReplace = true;
break;
}
}
#if NET6_0_OR_GREATER
// If the incoming chars don't contain any of the base64url characters that need to be replaced,
// and if the incoming chars are of the exact right length, then we'll be able to just pass the
// incoming chars directly to Convert.TryFromBase64Chars. Otherwise, rent an array, copy all the
// data into it, and do whatever fixups are necessary on that copy, then pass that copy into
// Convert.TryFromBase64Chars.

if (needReplace)
const int StackAllocThreshold = 512;
char[] arrayPoolChars = null;
scoped Span<char> charsSpan = default;
scoped ReadOnlySpan<char> source = str.Span;

if (needReplace || decodedLength != source.Length)
{
string decodedString = new(char.MinValue, decodedLength);
fixed (char* dest = decodedString)
charsSpan = decodedLength <= StackAllocThreshold ?
stackalloc char[StackAllocThreshold] :
arrayPoolChars = ArrayPool<char>.Shared.Rent(decodedLength);
charsSpan = charsSpan.Slice(0, decodedLength);

source.CopyTo(charsSpan);
if (source.Length < charsSpan.Length)
{
int i = 0;
for (; i < str.Length; i++)
charsSpan[source.Length] = base64PadCharacter;
if (source.Length + 1 < charsSpan.Length)
{
if (str[i] == base64UrlCharacter62)
dest[i] = base64Character62;
else if (str[i] == base64UrlCharacter63)
dest[i] = base64Character63;
else
dest[i] = str[i];
charsSpan[source.Length + 1] = base64PadCharacter;
}

for (; i < decodedLength; i++)
dest[i] = base64PadCharacter;
}

return Convert.FromBase64String(decodedString);
}
else
{
if (decodedLength == str.Length)
if (needReplace)
{
return Convert.FromBase64String(str);
}
else
{
string decodedString = new(char.MinValue, decodedLength);
fixed (char* src = str)
fixed (char* dest = decodedString)
int pos;
while ((pos = charsSpan.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63)) >= 0)
{
Buffer.MemoryCopy(src, dest, str.Length * 2, str.Length * 2);

dest[str.Length] = base64PadCharacter;
if (str.Length + 2 == decodedLength)
dest[str.Length + 1] = base64PadCharacter;
charsSpan[pos] = charsSpan[pos] == base64UrlCharacter62 ? base64Character62 : base64Character63;
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
}

return Convert.FromBase64String(decodedString);
}

source = charsSpan;
}
}

internal static unsafe byte[] UnsafeDecode(char[] str)
{
int mod = str.Length % 4;
if (mod == 1)
throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, str)));
byte[] arrayPoolBytes = null;
Span<byte> bytesSpan = decodedLength <= StackAllocThreshold ?
stackalloc byte[StackAllocThreshold] :
arrayPoolBytes = ArrayPool<byte>.Shared.Rent(decodedLength);

bool needReplace = false;
// the decoded length
int decodedLength = str.Length + (4 - mod) % 4;
bool converted = Convert.TryFromBase64Chars(source, bytesSpan, out int bytesWritten);
Debug.Assert(converted, "Expected TryFromBase64Chars to be successful");
byte[] result = bytesSpan.Slice(0, bytesWritten).ToArray();

for (int i = 0; i < str.Length; i++)
if (arrayPoolBytes is not null)
{
if (str[i] == base64UrlCharacter62 || str[i] == base64UrlCharacter63)
{
needReplace = true;
break;
}
bytesSpan.Clear();
ArrayPool<byte>.Shared.Return(arrayPoolBytes);
}

if (arrayPoolChars is not null)
{
charsSpan.Clear();
ArrayPool<char>.Shared.Return(arrayPoolChars);
}

return result;
#else
if (needReplace)
{
ReadOnlySpan<char> strSpan = str.Span;
string decodedString = new(char.MinValue, decodedLength);
fixed (char* dest = decodedString)
{
int i = 0;
for (; i < str.Length; i++)
for (; i < strSpan.Length; i++)
{
if (str[i] == base64UrlCharacter62)
if (strSpan[i] == base64UrlCharacter62)
dest[i] = base64Character62;
else if (str[i] == base64UrlCharacter63)
else if (strSpan[i] == base64UrlCharacter63)
dest[i] = base64Character63;
else
dest[i] = str[i];
dest[i] = strSpan[i];
}

for (; i < decodedLength; i++)
Expand All @@ -280,12 +267,21 @@ internal static unsafe byte[] UnsafeDecode(char[] str)
{
if (decodedLength == str.Length)
{
return Convert.FromBase64CharArray(str, 0, str.Length);
if (MemoryMarshal.TryGetArray(str, out ArraySegment<char> segment))
{
return Convert.FromBase64CharArray(segment.Array, segment.Offset, segment.Count);
}
else
{
bool gotString = MemoryMarshal.TryGetString(str, out string text, out int start, out int length);
Debug.Assert(gotString, "Expected ReadOnlyMemory to wrap either array or string");
return Convert.FromBase64String(text.Substring(start, length));
}
}
else
{
string decodedString = new(char.MinValue, decodedLength);
fixed (char* src = str)
fixed (char* src = str.Span)
fixed (char* dest = decodedString)
{
Buffer.MemoryCopy(src, dest, str.Length * 2, str.Length * 2);
Expand All @@ -298,6 +294,7 @@ internal static unsafe byte[] UnsafeDecode(char[] str)
return Convert.FromBase64String(decodedString);
}
}
#endif
}

/// <summary>
Expand Down