diff --git a/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs b/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs index 8885a8addb..cd351943d6 100644 --- a/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs +++ b/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs @@ -17,7 +17,6 @@ namespace Microsoft.IdentityModel.JsonWebTokens /// public class JsonWebToken : SecurityToken { - private char[] _hChars; private ClaimsIdentity _claimsIdentity; private bool _wasClaimsIdentitySet; @@ -470,16 +469,30 @@ private void ReadToken(string encodedJson) throw LogHelper.LogExceptionMessage(new SecurityTokenMalformedException(LogHelper.FormatInvariant(LogMessages.IDX14310, encodedJson))); // right number of dots for JWE - _hChars = encodedJson.ToCharArray(0, Dot1); + ReadOnlyMemory hChars = encodedJson.AsMemory(0, Dot1); // header cannot be empty - if (_hChars.Length == 0) + if (hChars.IsEmpty) throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14307, encodedJson))); - HeaderAsciiBytes = Encoding.ASCII.GetBytes(_hChars); + byte[] headerAsciiBytes = new byte[hChars.Length]; +#if NET6_0_OR_GREATER + Encoding.ASCII.GetBytes(hChars.Span, headerAsciiBytes); +#else + unsafe + { + fixed (char* hCharsPtr = hChars.Span) + fixed (byte* headerAsciiBytesPtr = headerAsciiBytes) + { + Encoding.ASCII.GetBytes(hCharsPtr, hChars.Length, headerAsciiBytesPtr, headerAsciiBytes.Length); + } + } +#endif + HeaderAsciiBytes = headerAsciiBytes; + try { - Header = new JsonClaimSet(Base64UrlEncoder.UnsafeDecode(_hChars)); + Header = new JsonClaimSet(Base64UrlEncoder.UnsafeDecode(hChars)); } catch (Exception ex) { @@ -487,8 +500,8 @@ private void ReadToken(string encodedJson) } // dir does not have any key bytes - char[] encryptedKeyBytes = encodedJson.ToCharArray(Dot1 + 1, Dot2 - Dot1 - 1); - if (encryptedKeyBytes.Length != 0) + ReadOnlyMemory encryptedKeyBytes = encodedJson.AsMemory(Dot1 + 1, Dot2 - Dot1 - 1); + if (!encryptedKeyBytes.IsEmpty) { EncryptedKeyBytes = Base64UrlEncoder.UnsafeDecode(encryptedKeyBytes); _encryptedKey = encodedJson.Substring(Dot1 + 1, Dot2 - Dot1 - 1); @@ -498,8 +511,8 @@ private void ReadToken(string encodedJson) _encryptedKey = string.Empty; } - char[] initializationVectorChars = encodedJson.ToCharArray(Dot2 + 1, Dot3 - Dot2 - 1); - if (initializationVectorChars.Length == 0) + ReadOnlyMemory initializationVectorChars = encodedJson.AsMemory(Dot2 + 1, Dot3 - Dot2 - 1); + if (initializationVectorChars.IsEmpty) throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14308, encodedJson))); try @@ -511,8 +524,8 @@ private void ReadToken(string encodedJson) throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14309, encodedJson, encodedJson), ex)); } - char[] authTagChars = encodedJson.ToCharArray(Dot4 + 1, encodedJson.Length - Dot4 - 1); - if (authTagChars.Length == 0) + ReadOnlyMemory authTagChars = encodedJson.AsMemory(Dot4 + 1); + if (authTagChars.IsEmpty) throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14310, encodedJson))); try @@ -524,13 +537,13 @@ private void ReadToken(string encodedJson) throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14311, encodedJson, encodedJson), ex)); } - char[] cipherTextBytes = encodedJson.ToCharArray(Dot3 + 1, Dot4 - Dot3 - 1); - if (cipherTextBytes.Length == 0) + ReadOnlyMemory cipherTextBytes = encodedJson.AsMemory(Dot3 + 1, Dot4 - Dot3 - 1); + if (cipherTextBytes.IsEmpty) throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX14306, encodedJson))); try { - CipherTextBytes = Base64UrlEncoder.UnsafeDecode(encodedJson.ToCharArray(Dot3 + 1, Dot4 - Dot3 - 1)); + CipherTextBytes = Base64UrlEncoder.UnsafeDecode(cipherTextBytes); } catch (Exception ex) { diff --git a/src/Microsoft.IdentityModel.JsonWebTokens/Microsoft.IdentityModel.JsonWebTokens.csproj b/src/Microsoft.IdentityModel.JsonWebTokens/Microsoft.IdentityModel.JsonWebTokens.csproj index 9523c38764..52e979fefd 100644 --- a/src/Microsoft.IdentityModel.JsonWebTokens/Microsoft.IdentityModel.JsonWebTokens.csproj +++ b/src/Microsoft.IdentityModel.JsonWebTokens/Microsoft.IdentityModel.JsonWebTokens.csproj @@ -8,6 +8,7 @@ true Microsoft.IdentityModel.JsonWebTokens .NET;Windows;Authentication;Identity;Json Web Token + true diff --git a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs index 27abb2ee10..a96179ca15 100644 --- a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs +++ b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs @@ -2,6 +2,10 @@ // Licensed under the MIT License. using System; +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Text; using Microsoft.IdentityModel.Logging; @@ -18,18 +22,6 @@ public static class Base64UrlEncoder private const char base64UrlCharacter62 = '-'; private const char base64UrlCharacter63 = '_'; - /// - /// Encoding table - /// - internal static readonly char[] s_base64Table = - { - 'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z', - 'a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z', - '0','1','2','3','4','5','6','7','8','9', - base64UrlCharacter62, - base64UrlCharacter63 - }; - /// /// The following functions perform base64url encoding which differs from regular base64 encoding as follows /// * padding is skipped so the pad character '=' doesn't have to be percent encoded @@ -90,7 +82,8 @@ public static string Encode(byte[] inArray, int offset, int length) int lengthmod3 = length % 3; int limit = offset + (length - lengthmod3); char[] output = new char[(length + 2) / 3 * 4]; - char[] table = s_base64Table; + ReadOnlySpan table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"u8; + int i, j = 0; // takes 3 bytes from inArray and insert 4 bytes into output @@ -100,10 +93,10 @@ public static string Encode(byte[] inArray, int offset, int length) byte d1 = inArray[i + 1]; byte d2 = inArray[i + 2]; - output[j + 0] = table[d0 >> 2]; - output[j + 1] = table[((d0 & 0x03) << 4) | (d1 >> 4)]; - output[j + 2] = table[((d1 & 0x0f) << 2) | (d2 >> 6)]; - output[j + 3] = table[d2 & 0x3f]; + output[j + 0] = (char)table[d0 >> 2]; + output[j + 1] = (char)table[((d0 & 0x03) << 4) | (d1 >> 4)]; + output[j + 2] = (char)table[((d1 & 0x0f) << 2) | (d2 >> 6)]; + output[j + 3] = (char)table[d2 & 0x3f]; j += 4; } @@ -117,9 +110,9 @@ public static string Encode(byte[] inArray, int offset, int length) byte d0 = inArray[i]; byte d1 = inArray[i + 1]; - output[j + 0] = table[d0 >> 2]; - output[j + 1] = table[((d0 & 0x03) << 4) | (d1 >> 4)]; - output[j + 2] = table[(d1 & 0x0f) << 2]; + output[j + 0] = (char)table[d0 >> 2]; + output[j + 1] = (char)table[((d0 & 0x03) << 4) | (d1 >> 4)]; + output[j + 2] = (char)table[(d1 & 0x0f) << 2]; j += 3; } break; @@ -128,8 +121,8 @@ public static string Encode(byte[] inArray, int offset, int length) { byte d0 = inArray[i]; - output[j + 0] = table[d0 >> 2]; - output[j + 1] = table[(d0 & 0x03) << 4]; + output[j + 0] = (char)table[d0 >> 2]; + output[j + 1] = (char)table[(d0 & 0x03) << 4]; j += 2; } break; @@ -168,106 +161,100 @@ internal static string EncodeString(string str) public static byte[] DecodeBytes(string str) { _ = str ?? throw LogHelper.LogExceptionMessage(new ArgumentNullException(nameof(str))); - return UnsafeDecode(str); + return UnsafeDecode(str.AsMemory()); } - internal static unsafe byte[] UnsafeDecode(string str) +#if NET6_0_OR_GREATER + [SkipLocalsInit] +#endif + internal static unsafe byte[] UnsafeDecode(ReadOnlyMemory str) { int mod = str.Length % 4; if (mod == 1) - throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, str))); + throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, str.ToString()))); - bool needReplace = false; + bool needReplace = str.Span.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63) >= 0; int decodedLength = str.Length + (4 - mod) % 4; - for (int i = 0; i < str.Length; i++) - { - if (str[i] == base64UrlCharacter62 || str[i] == base64UrlCharacter63) - { - needReplace = true; - break; - } - } +#if NET6_0_OR_GREATER + // If the incoming chars don't contain any of the base64url characters that need to be replaced, + // and if the incoming chars are of the exact right length, then we'll be able to just pass the + // incoming chars directly to Convert.TryFromBase64Chars. Otherwise, rent an array, copy all the + // data into it, and do whatever fixups are necessary on that copy, then pass that copy into + // Convert.TryFromBase64Chars. - if (needReplace) + const int StackAllocThreshold = 512; + char[] arrayPoolChars = null; + scoped Span charsSpan = default; + scoped ReadOnlySpan source = str.Span; + + if (needReplace || decodedLength != source.Length) { - string decodedString = new(char.MinValue, decodedLength); - fixed (char* dest = decodedString) + charsSpan = decodedLength <= StackAllocThreshold ? + stackalloc char[StackAllocThreshold] : + arrayPoolChars = ArrayPool.Shared.Rent(decodedLength); + charsSpan = charsSpan.Slice(0, decodedLength); + + source.CopyTo(charsSpan); + if (source.Length < charsSpan.Length) { - int i = 0; - for (; i < str.Length; i++) + charsSpan[source.Length] = base64PadCharacter; + if (source.Length + 1 < charsSpan.Length) { - if (str[i] == base64UrlCharacter62) - dest[i] = base64Character62; - else if (str[i] == base64UrlCharacter63) - dest[i] = base64Character63; - else - dest[i] = str[i]; + charsSpan[source.Length + 1] = base64PadCharacter; } - - for (; i < decodedLength; i++) - dest[i] = base64PadCharacter; } - return Convert.FromBase64String(decodedString); - } - else - { - if (decodedLength == str.Length) + if (needReplace) { - return Convert.FromBase64String(str); - } - else - { - string decodedString = new(char.MinValue, decodedLength); - fixed (char* src = str) - fixed (char* dest = decodedString) + int pos; + while ((pos = charsSpan.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63)) >= 0) { - Buffer.MemoryCopy(src, dest, str.Length * 2, str.Length * 2); - - dest[str.Length] = base64PadCharacter; - if (str.Length + 2 == decodedLength) - dest[str.Length + 1] = base64PadCharacter; + charsSpan[pos] = charsSpan[pos] == base64UrlCharacter62 ? base64Character62 : base64Character63; } - - return Convert.FromBase64String(decodedString); } + + source = charsSpan; } - } - internal static unsafe byte[] UnsafeDecode(char[] str) - { - int mod = str.Length % 4; - if (mod == 1) - throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, str))); + byte[] arrayPoolBytes = null; + Span bytesSpan = decodedLength <= StackAllocThreshold ? + stackalloc byte[StackAllocThreshold] : + arrayPoolBytes = ArrayPool.Shared.Rent(decodedLength); - bool needReplace = false; - // the decoded length - int decodedLength = str.Length + (4 - mod) % 4; + bool converted = Convert.TryFromBase64Chars(source, bytesSpan, out int bytesWritten); + Debug.Assert(converted, "Expected TryFromBase64Chars to be successful"); + byte[] result = bytesSpan.Slice(0, bytesWritten).ToArray(); - for (int i = 0; i < str.Length; i++) + if (arrayPoolBytes is not null) { - if (str[i] == base64UrlCharacter62 || str[i] == base64UrlCharacter63) - { - needReplace = true; - break; - } + bytesSpan.Clear(); + ArrayPool.Shared.Return(arrayPoolBytes); } + if (arrayPoolChars is not null) + { + charsSpan.Clear(); + ArrayPool.Shared.Return(arrayPoolChars); + } + + return result; +#else if (needReplace) { + ReadOnlySpan strSpan = str.Span; string decodedString = new(char.MinValue, decodedLength); fixed (char* dest = decodedString) { int i = 0; - for (; i < str.Length; i++) + for (; i < strSpan.Length; i++) { - if (str[i] == base64UrlCharacter62) + if (strSpan[i] == base64UrlCharacter62) dest[i] = base64Character62; - else if (str[i] == base64UrlCharacter63) + else if (strSpan[i] == base64UrlCharacter63) dest[i] = base64Character63; else - dest[i] = str[i]; + dest[i] = strSpan[i]; } for (; i < decodedLength; i++) @@ -280,12 +267,21 @@ internal static unsafe byte[] UnsafeDecode(char[] str) { if (decodedLength == str.Length) { - return Convert.FromBase64CharArray(str, 0, str.Length); + if (MemoryMarshal.TryGetArray(str, out ArraySegment segment)) + { + return Convert.FromBase64CharArray(segment.Array, segment.Offset, segment.Count); + } + else + { + bool gotString = MemoryMarshal.TryGetString(str, out string text, out int start, out int length); + Debug.Assert(gotString, "Expected ReadOnlyMemory to wrap either array or string"); + return Convert.FromBase64String(text.Substring(start, length)); + } } else { string decodedString = new(char.MinValue, decodedLength); - fixed (char* src = str) + fixed (char* src = str.Span) fixed (char* dest = decodedString) { Buffer.MemoryCopy(src, dest, str.Length * 2, str.Length * 2); @@ -298,6 +294,7 @@ internal static unsafe byte[] UnsafeDecode(char[] str) return Convert.FromBase64String(decodedString); } } +#endif } ///