From 438a1f16c0684a0cfe1ee1bb739846967a5bece6 Mon Sep 17 00:00:00 2001 From: Brett White Date: Tue, 23 Jul 2024 16:06:52 -0700 Subject: [PATCH] Use new Base64Url API --- .../Base64UrlEncoderTests.cs | 86 +++++++ .../identitymodel.benchmarks.yml | 5 + build/dependencies.props | 1 + .../JsonWebToken.cs | 13 +- .../Base64UrlEncoder.cs | 223 ++++++------------ .../Base64UrlEncoding.cs | 185 ++------------- .../LogMessages.cs | 4 - .../Microsoft.IdentityModel.Tokens.csproj | 1 + .../Base64UrlEncodingTests.cs | 18 +- 9 files changed, 216 insertions(+), 320 deletions(-) create mode 100644 benchmark/Microsoft.IdentityModel.Benchmarks/Base64UrlEncoderTests.cs diff --git a/benchmark/Microsoft.IdentityModel.Benchmarks/Base64UrlEncoderTests.cs b/benchmark/Microsoft.IdentityModel.Benchmarks/Base64UrlEncoderTests.cs new file mode 100644 index 0000000000..30a56e4a69 --- /dev/null +++ b/benchmark/Microsoft.IdentityModel.Benchmarks/Base64UrlEncoderTests.cs @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers.Text; +using BenchmarkDotNet.Attributes; +using Microsoft.IdentityModel.Tokens; + +namespace Microsoft.IdentityModel.Benchmarks +{ + // dotnet run -c release -f net9.0 --filter Microsoft.IdentityModel.Benchmarks.Base64UrlEncoderTests* + + [MemoryDiagnoser] + public class Base64UrlEncoderTests + { + private const string base64UrlEncodedString = "V1aOrL0rTJZYRxPHmMZdTWR0-ilwg9V9iEoKSn3CYl5vmBNqsM0x4VtvRbK8nCmnCIc2__QE92D4vQDR8AQ2j_BljJUkNY51VrbZ1wBar_7X2NktF_AQLkqDmwuagjhONR9_MIVysq36EAqxoHAwHHJx87XrrOkPDD8kiQ2uZEgPgK-4o02hhjsETU7KWiOKg4nKlLUU2YwuW2ZQxVubPfEv5SrW8BDgvNwPseyXfKrznhNAQHgwUX6sh1lTBm-cQdujkNsG62DeJSA2o9A_IhpKOuyQpaNda6U8jbBh3FGZhmFAm6yxNag3b6jAVlxphRNDvlm6UprgoFbvzcuH8W5ZH60LjNxsSKLH8W3gHIc7jhDA0vH2T8Nf2HEqFmqcsGr6aNm86ilWg1tchS_DlFPWqu8Wm3EEHTSJcd7BxMTvr9syRLICmhVsfHwdgMy1WfKklnyGJ_RT3kvbfCPQ2sSRMiOqCkdwCUECu-CcxS4CiIanlWnIpllmBov6vawcR6o6gmcFuqxhw2rp3815glnF7jNkmr7hsd0DPQ7qRUOHlGkF8_Sgretbgpb61y8a8DlVLlb7nBBQbTFif-lBAH4gfWWeNF9A3RFPQ8e8UKghJ7u_4ua9W_Lk_xpDkyGDXrkAzTYLxOGujRaWexOpwWSOKsXgIqXa94px0HAUIAVwP2Gy_gWcVz47ayedXh1Tcqb3K1hDlzZt4XK6O9eu-lAgy6gBltSrkntumDB-XEkxRabh8FNMln_LeEh_TgwWX4iVBR1-VD-VJw1e_aypVWj_E178TjCeb6Lc9pKD_r2VAieZpVp0c15g3vxznBWPD5mviHnK_NbSiccodSfpzGJbUsBuvKvhK4EFSw4_YlWJFlEXj3XYtiqO60crVynlEEqegLncI6RrjWe8WEfXEm_yeiglH5I-asU5sl0pBdLRdeg1xo1SZfR-CtgJ0dliwGkPDE6HcyGqhddMbIze_5I8ZazQ31PQaShhXtdH3K_cWXe4WhpR-_qYTrwib89ux2zZxePCkb_RXyvd09hv1J1kkmTf9f7q1xXfiBw49Iun90tJaOMru6PeL3Ayixj4d2C-rnwS43jcRJJ_SBiRgpBQo3Gg893UkxY2l2prQa-zU9GdbwlfDF9Htijxm75SuoxOldhTFDcpw6QqKjt1116gfkmgg16hXjvNhV8sCqxmHdKoIM6EOKVy5MAIJcg_-wbAVhbJQ205udIPb49GY1yDePieu2eQa6TU8Pn66YK5Kl4K6kCmOY6NpDdhDk6BwyJ6Z9wz2nF8OwF2mDKpMdP2nkFnq8iq2z9o7s7HwIP8pbr99kvMlw"; + // Add padding + private static readonly string base64EncodedString = base64UrlEncodedString.Replace('-', '+').Replace('_', '/') + "=="; + // Add "padding" without adding special characters + private static readonly string base64NoSpecialCharsEncodedString = base64UrlEncodedString.Replace('-', 'A').Replace('_', 'B') + "ab"; + // Add padding as only special characters (Base64-encoded but could be decoded with Base64Url API) + private static readonly string base64NoSpecialCharsExceptPaddingEncodedString = base64UrlEncodedString.Replace('-', 'A').Replace('_', 'B') + "=="; + private static readonly string decodedString = Base64UrlEncoder.Decode(base64UrlEncodedString); + private static readonly byte[] decodedBytes = Base64UrlEncoder.DecodeBytes(base64UrlEncodedString); + + [Benchmark] + public void Decode_String_Base64Url() => Base64UrlEncoder.Decode(base64UrlEncodedString); + + [Benchmark] + public void Decode_Span_Base64Url() => Base64UrlEncoder.Decode(base64UrlEncodedString.AsSpan()); + + [Benchmark] + public void DecodeBytes_Base64Url() => Base64UrlEncoder.DecodeBytes(base64UrlEncodedString); + + [Benchmark] + public void Decode_Span_Output_Base64Url() => Base64UrlEncoder.Decode(base64UrlEncodedString.AsSpan(), new byte[Base64.GetMaxDecodedFromUtf8Length(base64UrlEncodedString.Length + 2)]); + + [Benchmark] + public void Decode_String_Base64() => Base64UrlEncoder.Decode(base64EncodedString); + + [Benchmark] + public void Decode_Span_Base64() => Base64UrlEncoder.Decode(base64EncodedString.AsSpan()); + + [Benchmark] + public void DecodeBytes_Base64() => Base64UrlEncoder.DecodeBytes(base64EncodedString); + + [Benchmark] + public void Decode_Span_Output_Base64() => Base64UrlEncoder.Decode(base64EncodedString.AsSpan(), new byte[Base64.GetMaxDecodedFromUtf8Length(base64EncodedString.Length + 2)]); + + [Benchmark] + public void Decode_String_Base64NoSpecialChars() => Base64UrlEncoder.Decode(base64NoSpecialCharsEncodedString); + + [Benchmark] + public void Decode_Span_Base64NoSpecialChars() => Base64UrlEncoder.Decode(base64NoSpecialCharsEncodedString.AsSpan()); + + [Benchmark] + public void DecodeBytes_Base64NoSpecialChars() => Base64UrlEncoder.DecodeBytes(base64NoSpecialCharsEncodedString); + + [Benchmark] + public void Decode_Span_Output_Base64NoSpecialChars() => Base64UrlEncoder.Decode(base64NoSpecialCharsEncodedString.AsSpan(), new byte[Base64.GetMaxDecodedFromUtf8Length(base64NoSpecialCharsEncodedString.Length + 2)]); + + [Benchmark] + public void Decode_String_Base64NoSpecialCharsExceptPadding() => Base64UrlEncoder.Decode(base64NoSpecialCharsExceptPaddingEncodedString); + + [Benchmark] + public void Decode_Span_Base64NoSpecialCharsExceptPadding() => Base64UrlEncoder.Decode(base64NoSpecialCharsExceptPaddingEncodedString.AsSpan()); + + [Benchmark] + public void DecodeBytes_Base64NoSpecialCharsExceptPadding() => Base64UrlEncoder.DecodeBytes(base64NoSpecialCharsExceptPaddingEncodedString); + + [Benchmark] + public void Decode_Span_Output_Base64NoSpecialCharsExceptPadding() => Base64UrlEncoder.Decode(base64NoSpecialCharsExceptPaddingEncodedString.AsSpan(), new byte[Base64.GetMaxDecodedFromUtf8Length(base64NoSpecialCharsExceptPaddingEncodedString.Length + 2)]); + + [Benchmark] + public void Encode_String_Base64Url() => Base64UrlEncoder.Encode(decodedString); + + [Benchmark] + public void Encode_Bytes_Base64Url() => Base64UrlEncoder.Encode(decodedBytes); + + [Benchmark] + public void Encode_Span_Base64Url() => Base64UrlEncoder.Encode(decodedBytes, new char[Base64.GetMaxEncodedToUtf8Length(decodedBytes.Length)]); + + [Benchmark] + public void Encode_Bytes_Offset_Length_Base64Url() => Base64UrlEncoder.Encode(decodedBytes, decodedBytes.Length / 2, decodedBytes.Length / 2 - 10); + } +} diff --git a/benchmark/Microsoft.IdentityModel.Benchmarks/identitymodel.benchmarks.yml b/benchmark/Microsoft.IdentityModel.Benchmarks/identitymodel.benchmarks.yml index 8997e8a3eb..ee33073e46 100644 --- a/benchmark/Microsoft.IdentityModel.Benchmarks/identitymodel.benchmarks.yml +++ b/benchmark/Microsoft.IdentityModel.Benchmarks/identitymodel.benchmarks.yml @@ -57,3 +57,8 @@ scenarios: variables: filterArg: "*ValidateTokenAsyncTests*" + Base64UrlEncoderTests: + application: + job: benchmarks + variables: + filterArg: "*Base64UrlEncoderTests*" diff --git a/build/dependencies.props b/build/dependencies.props index 79eac92835..90408aa8b1 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -11,6 +11,7 @@ 4.5.5 4.5.0 8.0.4 + 9.0.0-rc.1.24431.7 diff --git a/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs b/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs index 48af759d86..9db92fc295 100644 --- a/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs +++ b/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebToken.cs @@ -569,15 +569,22 @@ internal JsonClaimSet CreateClaimSet(ReadOnlySpan strSpan, int startIndex, { int outputSize = Base64UrlEncoding.ValidateAndGetOutputSize(strSpan, startIndex, length); - byte[] output = ArrayPool.Shared.Rent(outputSize); + byte[] rented = null; + + const int MaxStackallocThreshold = 256; + Span output = outputSize <= MaxStackallocThreshold + ? stackalloc byte[outputSize] + : (rented = ArrayPool.Shared.Rent(outputSize)); + try { Base64UrlEncoder.Decode(strSpan.Slice(startIndex, length), output); - return createHeaderClaimSet ? CreateHeaderClaimSet(output.AsSpan()) : CreatePayloadClaimSet(output.AsSpan()); + return createHeaderClaimSet ? CreateHeaderClaimSet(output) : CreatePayloadClaimSet(output); } finally { - ArrayPool.Shared.Return(output, true); + if (rented is not null) + ArrayPool.Shared.Return(rented, true); } } diff --git a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs index 7fcd84cb1a..eb8c98e036 100644 --- a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs +++ b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs @@ -3,6 +3,7 @@ using System; using System.Buffers; +using System.Buffers.Text; using System.Text; using Microsoft.IdentityModel.Logging; @@ -17,11 +18,9 @@ namespace Microsoft.IdentityModel.Tokens /// public static class Base64UrlEncoder { - private const char base64PadCharacter = '='; - private const char base64Character62 = '+'; - private const char base64Character63 = '/'; - private const char base64UrlCharacter62 = '-'; - private const char base64UrlCharacter63 = '_'; + private const char Base64PadCharacter = '='; + private const char Base64Character62 = '+'; + private const char Base64Character63 = '/'; /// /// Performs base64url encoding, which differs from regular base64 encoding as follows: @@ -99,10 +98,7 @@ public static string Encode(byte[] inArray, int offset, int length) LogHelper.MarkAsNonPII(inArray.Length)))); #pragma warning restore CA2208 // Instantiate argument exceptions correctly - char[] destination = new char[(inArray.Length + 2) / 3 * 4]; - int j = Encode(inArray.AsSpan().Slice(offset, length), destination.AsSpan()); - - return new string(destination, 0, j); + return Base64Url.EncodeToString(inArray.AsSpan().Slice(offset, length)); } /// @@ -111,60 +107,7 @@ public static string Encode(byte[] inArray, int offset, int length) /// A read-only span of bytes to encode. /// The span of characters to write the encoded output. /// The number of characters written to the output span. - public static int Encode(ReadOnlySpan inArray, Span output) - { - int lengthmod3 = inArray.Length % 3; - int limit = (inArray.Length - lengthmod3); - ReadOnlySpan table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"u8; - - int i, j = 0; - - // takes 3 bytes from inArray and insert 4 bytes into output - for (i = 0; i < limit; i += 3) - { - byte d0 = inArray[i]; - byte d1 = inArray[i + 1]; - byte d2 = inArray[i + 2]; - - output[j + 0] = (char)table[d0 >> 2]; - output[j + 1] = (char)table[((d0 & 0x03) << 4) | (d1 >> 4)]; - output[j + 2] = (char)table[((d1 & 0x0f) << 2) | (d2 >> 6)]; - output[j + 3] = (char)table[d2 & 0x3f]; - j += 4; - } - - //Where we left off before - i = limit; - - switch (lengthmod3) - { - case 2: - { - byte d0 = inArray[i]; - byte d1 = inArray[i + 1]; - - output[j + 0] = (char)table[d0 >> 2]; - output[j + 1] = (char)table[((d0 & 0x03) << 4) | (d1 >> 4)]; - output[j + 2] = (char)table[(d1 & 0x0f) << 2]; - j += 3; - } - break; - - case 1: - { - byte d0 = inArray[i]; - - output[j + 0] = (char)table[d0 >> 2]; - output[j + 1] = (char)table[(d0 & 0x03) << 4]; - j += 2; - } - break; - - //default or case 0: no further operations are needed. - } - - return j; - } + public static int Encode(ReadOnlySpan inArray, Span output) => Base64Url.EncodeToChars(inArray, output); /// /// Converts the specified base64url encoded string to UTF-8 bytes. @@ -177,47 +120,70 @@ public static byte[] DecodeBytes(string str) return Decode(str.AsSpan()); } +#if NETCOREAPP + [SkipLocalsInit] +#endif internal static byte[] Decode(ReadOnlySpan strSpan) { - int mod = strSpan.Length % 4; - if (mod == 1) - throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString()))); + int upperBound = Base64Url.GetMaxDecodedLength(strSpan.Length); + byte[] rented = null; - bool needReplace = strSpan.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63) >= 0; - int decodedLength = strSpan.Length + (4 - mod) % 4; + const int MaxStackallocThreshold = 256; + Span destination = upperBound <= MaxStackallocThreshold + ? stackalloc byte[upperBound] + : (rented = ArrayPool.Shared.Rent(upperBound)); -#if NET6_0_OR_GREATER + try + { + int bytesWritten = Decode(strSpan, destination); + return destination.Slice(0, bytesWritten).ToArray(); + } + finally + { + if (rented is not null) + ArrayPool.Shared.Return(rented, true); + } + } - Span output = new byte[decodedLength]; +#if !NET8_0_OR_GREATER + private static bool IsOnlyValidBase64Chars(ReadOnlySpan strSpan) + { + foreach (char c in strSpan) + if (!char.IsDigit(c) && !char.IsLetter(c) && c != Base64Character62 && c != Base64Character63 && c != Base64PadCharacter) + return false; - int length = Decode(strSpan, output, needReplace, decodedLength); + return true; + } + +#endif +#if NETCOREAPP + [SkipLocalsInit] +#endif + internal static int Decode(ReadOnlySpan strSpan, Span output) + { + OperationStatus status = Base64Url.DecodeFromChars(strSpan, output, out _, out int bytesWritten); + if (status == OperationStatus.Done) + return bytesWritten; - return output.Slice(0, length).ToArray(); + if (status == OperationStatus.InvalidData && +#if NET8_0_OR_GREATER + !Base64.IsValid(strSpan)) #else - return UnsafeDecode(strSpan, needReplace, decodedLength); + !IsOnlyValidBase64Chars(strSpan)) #endif - } + throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString()))); - internal static void Decode(ReadOnlySpan strSpan, Span output) - { int mod = strSpan.Length % 4; if (mod == 1) throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString()))); - - bool needReplace = strSpan.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63) >= 0; int decodedLength = strSpan.Length + (4 - mod) % 4; -#if NET6_0_OR_GREATER - Decode(strSpan, output, needReplace, decodedLength); -#else - Decode(strSpan, output, needReplace, decodedLength); -#endif + return Decode(strSpan, output, decodedLength); } -#if NET6_0_OR_GREATER - +#if NETCOREAPP [SkipLocalsInit] - private static int Decode(ReadOnlySpan strSpan, Span output, bool needReplace, int decodedLength) + private static int Decode(ReadOnlySpan strSpan, Span output, int decodedLength) { // If the incoming chars don't contain any of the base64url characters that need to be replaced, // and if the incoming chars are of the exact right length, then we'll be able to just pass the @@ -230,14 +196,14 @@ private static int Decode(ReadOnlySpan strSpan, Span output, bool ne scoped Span charsSpan = default; scoped ReadOnlySpan source = strSpan; - if (needReplace || decodedLength != source.Length) + if (decodedLength != source.Length) { charsSpan = decodedLength <= StackAllocThreshold ? stackalloc char[StackAllocThreshold] : arrayPoolChars = ArrayPool.Shared.Rent(decodedLength); charsSpan = charsSpan.Slice(0, decodedLength); - source = HandlePaddingAndReplace(source, charsSpan, needReplace); + source = HandlePadding(source, charsSpan); } byte[] arrayPoolBytes = null; @@ -250,7 +216,7 @@ private static int Decode(ReadOnlySpan strSpan, Span output, bool ne try { - OperationStatus status = System.Buffers.Text.Base64.DecodeFromUtf8InPlace(utf8Span, out int bytesWritten); + OperationStatus status = Base64.DecodeFromUtf8InPlace(utf8Span, out int bytesWritten); if (status != OperationStatus.Done) throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString()))); @@ -274,86 +240,47 @@ private static int Decode(ReadOnlySpan strSpan, Span output, bool ne } } - private static ReadOnlySpan HandlePaddingAndReplace(ReadOnlySpan source, Span charsSpan, bool needReplace) + private static ReadOnlySpan HandlePadding(ReadOnlySpan source, Span charsSpan) { source.CopyTo(charsSpan); if (source.Length < charsSpan.Length) { - charsSpan[source.Length] = base64PadCharacter; + charsSpan[source.Length] = Base64PadCharacter; if (source.Length + 1 < charsSpan.Length) { - charsSpan[source.Length + 1] = base64PadCharacter; - } - } - - if (needReplace) - { - Span remaining = charsSpan; - int pos; - while ((pos = remaining.IndexOfAny(base64UrlCharacter62, base64UrlCharacter63)) >= 0) - { - remaining[pos] = (remaining[pos] == base64UrlCharacter62) ? base64Character62 : base64Character63; - remaining = remaining.Slice(pos + 1); + charsSpan[source.Length + 1] = Base64PadCharacter; } } return charsSpan; } - #else - - private static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan, bool needReplace, int decodedLength) + private static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan, int decodedLength) { - if (needReplace) + if (decodedLength == strSpan.Length) { - string decodedString = new(char.MinValue, decodedLength); - fixed (char* dest = decodedString) - { - int i = 0; - for (; i < strSpan.Length; i++) - { - if (strSpan[i] == base64UrlCharacter62) - dest[i] = base64Character62; - else if (strSpan[i] == base64UrlCharacter63) - dest[i] = base64Character63; - else - dest[i] = strSpan[i]; - } - - for (; i < decodedLength; i++) - dest[i] = base64PadCharacter; - } - - return Convert.FromBase64String(decodedString); + return Convert.FromBase64CharArray(strSpan.ToArray(), 0, strSpan.Length); } - else + + string decodedString = new(char.MinValue, decodedLength); + fixed (char* src = strSpan) + fixed (char* dest = decodedString) { - if (decodedLength == strSpan.Length) - { - return Convert.FromBase64CharArray(strSpan.ToArray(), 0, strSpan.Length); - } - else - { - string decodedString = new(char.MinValue, decodedLength); - fixed (char* src = strSpan) - fixed (char* dest = decodedString) - { - Buffer.MemoryCopy(src, dest, strSpan.Length * 2, strSpan.Length * 2); - - dest[strSpan.Length] = base64PadCharacter; - if (strSpan.Length + 2 == decodedLength) - dest[strSpan.Length + 1] = base64PadCharacter; - } - - return Convert.FromBase64String(decodedString); - } + Buffer.MemoryCopy(src, dest, strSpan.Length * 2, strSpan.Length * 2); + + dest[strSpan.Length] = Base64PadCharacter; + if (strSpan.Length + 2 == decodedLength) + dest[strSpan.Length + 1] = Base64PadCharacter; } + + return Convert.FromBase64String(decodedString); } - private static void Decode(ReadOnlySpan strSpan, Span output, bool needReplace, int decodedLength) + private static int Decode(ReadOnlySpan strSpan, Span output, int decodedLength) { - byte[] result = UnsafeDecode(strSpan, needReplace, decodedLength); + byte[] result = UnsafeDecode(strSpan, decodedLength); result.CopyTo(output); + return result.Length; } #endif diff --git a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoding.cs b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoding.cs index 558aaf54bf..f2d6cb4f55 100644 --- a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoding.cs +++ b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoding.cs @@ -3,37 +3,13 @@ using System; using System.Buffers; +using System.Buffers.Text; using Microsoft.IdentityModel.Logging; namespace Microsoft.IdentityModel.Tokens { - /// - /// Base64 encode/decode implementation for as per https://tools.ietf.org/html/rfc4648#section-5. - /// Uses ArrayPool[T] to minimize memory usage. - /// internal static class Base64UrlEncoding { - private const uint IntA = 'A'; - private const uint IntZ = 'Z'; - private const uint Inta = 'a'; - private const uint Intz = 'z'; - private const uint Int0 = '0'; - private const uint Int9 = '9'; - private const uint IntEq = '='; - private const uint IntPlus = '+'; - private const uint IntMinus = '-'; - private const uint IntSlash = '/'; - private const uint IntUnderscore = '_'; - - private static readonly char[] Base64Table = - { - 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', - 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', - 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', - 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', - '8', '9', '-', '_', - }; - /// /// Decodes a base64url encoded string into a byte array. /// @@ -58,8 +34,8 @@ public static byte[] Decode(string input, int offset, int length) _ = input ?? throw LogHelper.LogArgumentNullException(nameof(input)); ReadOnlySpan inputSpan = input.AsSpan(); - int outputsize = ValidateAndGetOutputSize(inputSpan, offset, length); - byte[] output = new byte[outputsize]; + int outputSize = ValidateAndGetOutputSize(inputSpan, offset, length); + byte[] output = new byte[outputSize]; Decode(inputSpan, offset, length, output); return output; } @@ -85,16 +61,17 @@ public static T Decode(string input, int offset, int length, TX argx, Fun _ = action ?? throw new ArgumentNullException(nameof(action)); ReadOnlySpan inputSpan = input.AsSpan(); - int outputsize = ValidateAndGetOutputSize(inputSpan, offset, length); - byte[] output = ArrayPool.Shared.Rent(outputsize); + int outputSize = ValidateAndGetOutputSize(inputSpan, offset, length); + byte[] output = ArrayPool.Shared.Rent(outputSize); + try { Decode(inputSpan, offset, length, output); - return action(output, outputsize, argx); + return action(output, outputSize, argx); } finally { - ArrayPool.Shared.Return(output); + ArrayPool.Shared.Return(output, true); } } @@ -118,16 +95,17 @@ public static T Decode(string input, int offset, int length, Func inputSpan = input.AsSpan(); - int outputsize = ValidateAndGetOutputSize(inputSpan, offset, length); - byte[] output = ArrayPool.Shared.Rent(outputsize); + int outputSize = ValidateAndGetOutputSize(inputSpan, offset, length); + byte[] output = ArrayPool.Shared.Rent(outputSize); + try { Decode(inputSpan, offset, length, output); - return action(output, outputsize); + return action(output, outputSize); } finally { - ArrayPool.Shared.Return(output); + ArrayPool.Shared.Return(output, true); } } @@ -163,16 +141,17 @@ public static T Decode( _ = action ?? throw LogHelper.LogArgumentNullException(nameof(action)); ReadOnlySpan inputSpan = input.AsSpan(); - int outputsize = ValidateAndGetOutputSize(inputSpan, offset, length); - byte[] output = ArrayPool.Shared.Rent(outputsize); + int outputSize = ValidateAndGetOutputSize(inputSpan, offset, length); + byte[] output = ArrayPool.Shared.Rent(outputSize); + try { Decode(inputSpan, offset, length, output); - return action(output, outputsize, argx, argy, argz); + return action(output, outputSize, argx, argy, argz); } finally { - ArrayPool.Shared.Return(output); + ArrayPool.Shared.Return(output, true); } } @@ -183,87 +162,8 @@ public static T Decode( /// The index of the character in to start decoding from. /// The number of characters beginning from to decode. /// The byte array to place the decoded results into. - /// - /// Changes from Base64UrlEncoder implementation: - /// 1. Padding is optional. - /// 2. '+' and '-' are treated the same. - /// 3. '/' and '_' are treated the same. - /// - internal static void Decode(ReadOnlySpan input, int offset, int length, byte[] output) - { - int outputpos = 0; - uint curblock = 0x000000FFu; - for (int i = offset; i < (offset + length); i++) - { - uint cur = input[i]; - if (cur >= IntA && cur <= IntZ) - { - cur -= IntA; - } - else if (cur >= Inta && cur <= Intz) - { - cur = (cur - Inta) + 26u; - } - else if (cur >= Int0 && cur <= Int9) - { - cur = (cur - Int0) + 52u; - } - else if (cur == IntPlus || cur == IntMinus) - { - cur = 62u; - } - else if (cur == IntSlash || cur == IntUnderscore) - { - cur = 63u; - } - else if (cur == IntEq) - { - continue; - } - else - { - throw LogHelper.LogExceptionMessage(new ArgumentOutOfRangeException( - LogHelper.FormatInvariant( - LogMessages.IDX10820, - LogHelper.MarkAsNonPII(cur), - input.ToString()))); - } - - curblock = (curblock << 6) | cur; - - // check if 4 characters have been read, based on number of shifts. - if ((0xFF000000u & curblock) == 0xFF000000u) - { - output[outputpos++] = (byte)(curblock >> 16); - output[outputpos++] = (byte)(curblock >> 8); - output[outputpos++] = (byte)curblock; - curblock = 0x000000FFu; - } - } - - // Handle spill over characters. This accounts for case where padding character is not present. - if (curblock != 0x000000FFu) - { - if ((0x03FC0000u & curblock) == 0x03FC0000u) - { - // shifted 3 times, 1 padding character, 2 output characters - curblock <<= 6; - output[outputpos++] = (byte)(curblock >> 16); - output[outputpos++] = (byte)(curblock >> 8); - } - else if ((0x000FF000u & curblock) == 0x000FF000u) - { - // shifted 2 times, 2 padding character, 1 output character - curblock <<= 12; - output[outputpos++] = (byte)(curblock >> 16); - } - else - { - throw LogHelper.LogExceptionMessage(new ArgumentException( - LogHelper.FormatInvariant(LogMessages.IDX10821, input.ToString()))); - } - } - } + internal static void Decode(ReadOnlySpan input, int offset, int length, byte[] output) => + Base64Url.DecodeFromChars(input.Slice(offset, length), output); /// /// Encodes a byte array into a base64url encoded string. @@ -320,15 +220,7 @@ public static string Encode(byte[] input, int offset, int length) LogHelper.MarkAsNonPII(input.Length)))); #pragma warning restore CA2208 // Instantiate argument exceptions correctly - int outputsize = length % 3; - if (outputsize > 0) - outputsize++; - - outputsize += (length / 3) * 4; - - char[] output = new char[outputsize]; - WriteEncodedOutput(input, offset, length, output); - return new string(output); + return Base64Url.EncodeToString(input.AsSpan().Slice(offset, length)); } /// @@ -392,40 +284,5 @@ internal static int ValidateAndGetOutputSize(ReadOnlySpan strSpan, int off outputSize += (effectiveLength / 4) * 3; return outputSize; } - - private static void WriteEncodedOutput(byte[] inputBytes, int offset, int length, Span output) - { - uint curBlock = 0x000000FFu; - int outputPointer = 0; - - for (int i = offset; i < offset + length; i++) - { - curBlock = (curBlock << 8) | inputBytes[i]; - - if ((curBlock & 0xFF000000u) == 0xFF000000u) - { - output[outputPointer++] = Base64Table[(curBlock & 0x00FC0000u) >> 18]; - output[outputPointer++] = Base64Table[(curBlock & 0x00030000u | curBlock & 0x0000F000u) >> 12]; - output[outputPointer++] = Base64Table[(curBlock & 0x00000F00u | curBlock & 0x000000C0u) >> 6]; - output[outputPointer++] = Base64Table[curBlock & 0x0000003Fu]; - - curBlock = 0x000000FFu; - } - } - - if ((curBlock & 0x00FF0000u) == 0x00FF0000u) - { - // 2 shifts, 3 output characters. - output[outputPointer++] = Base64Table[(curBlock & 0x0000FC00u) >> 10]; - output[outputPointer++] = Base64Table[(curBlock & 0x000003F0u) >> 4]; - output[outputPointer++] = Base64Table[(curBlock & 0x0000000Fu) << 2]; - } - else if ((curBlock & 0x0000FF00u) == 0x0000FF00u) - { - // 1 shift, 2 output characters. - output[outputPointer++] = Base64Table[(curBlock & 0x000000FCu) >> 2]; - output[outputPointer++] = Base64Table[(curBlock & 0x00000003u) << 4]; - } - } } } diff --git a/src/Microsoft.IdentityModel.Tokens/LogMessages.cs b/src/Microsoft.IdentityModel.Tokens/LogMessages.cs index 8a282dd2fc..03488753a6 100644 --- a/src/Microsoft.IdentityModel.Tokens/LogMessages.cs +++ b/src/Microsoft.IdentityModel.Tokens/LogMessages.cs @@ -252,10 +252,6 @@ internal static class LogMessages public const string IDX10815 = "IDX10815: Depth of JSON: '{0}' exceeds max depth of '{1}'."; public const string IDX10816 = "IDX10816: Decompressing would result in a token with a size greater than allowed. Maximum size allowed: '{0}'."; - // Base64UrlEncoding - public const string IDX10820 = "IDX10820: Invalid character found in Base64UrlEncoding. Character: '{0}', Encoding: '{1}'."; - public const string IDX10821 = "IDX10821: Incorrect padding detected in Base64UrlEncoding. Encoding: '{0}'."; - //EventBasedLRUCache errors public const string IDX10900 = "IDX10900: EventBasedLRUCache._eventQueue encountered an error while processing a cache operation. Exception '{0}'."; public const string IDX10901 = "IDX10901: CryptoProviderCacheOptions.SizeLimit must be greater than 10. Value: '{0}'"; diff --git a/src/Microsoft.IdentityModel.Tokens/Microsoft.IdentityModel.Tokens.csproj b/src/Microsoft.IdentityModel.Tokens/Microsoft.IdentityModel.Tokens.csproj index 6636af5a56..2ac35ea55d 100644 --- a/src/Microsoft.IdentityModel.Tokens/Microsoft.IdentityModel.Tokens.csproj +++ b/src/Microsoft.IdentityModel.Tokens/Microsoft.IdentityModel.Tokens.csproj @@ -39,6 +39,7 @@ + diff --git a/test/Microsoft.IdentityModel.Tokens.Tests/Base64UrlEncodingTests.cs b/test/Microsoft.IdentityModel.Tokens.Tests/Base64UrlEncodingTests.cs index 71883d4822..b220959074 100644 --- a/test/Microsoft.IdentityModel.Tokens.Tests/Base64UrlEncodingTests.cs +++ b/test/Microsoft.IdentityModel.Tokens.Tests/Base64UrlEncodingTests.cs @@ -33,6 +33,9 @@ public void EncodeTests(Base64UrlEncoderTheoryData theoryData) string encodingString = Base64UrlEncoding.Encode(theoryData.Bytes); string encodingBytesUsingOffset = Base64UrlEncoding.Encode(theoryData.OffsetBytes, theoryData.Offset, theoryData.Length); + byte[] decodedBytes = theoryData.Bytes?.Length == 0 ? Array.Empty() : Base64UrlEncoding.Decode(encodingString); + const string randomPadding = "RANDOMPADDING"; + byte[] decodedBytes2 = theoryData.Bytes?.Length == 0 ? Array.Empty() : Base64UrlEncoding.Decode(randomPadding + encodingString + randomPadding, randomPadding.Length, encodingString.Length); theoryData.ExpectedException.ProcessNoException(context); @@ -46,7 +49,8 @@ public void EncodeTests(Base64UrlEncoderTheoryData theoryData) IdentityComparer.AreStringsEqual(encodingBytesUsingOffset, encodingString, "encodingBytesUsingOffset", "encodingString", context); IdentityComparer.AreStringsEqual(theoryData.ExpectedValue, encodingString, "theoryData.ExpectedValue", "encodingString", context); - + IdentityComparer.AreEqual(theoryData.Bytes, decodedBytes, context); + IdentityComparer.AreEqual(theoryData.Bytes, decodedBytes2, context); } catch (Exception ex) { @@ -327,5 +331,17 @@ public void ValidateAndGetOutputSizeTests() actualOutputSize = Base64UrlEncoding.ValidateAndGetOutputSize("abc=".AsSpan(), 0, 4); Assert.Equal(2, actualOutputSize); } + + [Fact] + public void EncodeDecodeExceptionTests() + { + Assert.Throws(static () => Base64UrlEncoding.Decode(null)); + Assert.Throws(static () => Base64UrlEncoding.Decode(null, 0, 0)); + Assert.Throws(static () => Base64UrlEncoding.Encode(null)); + Assert.Throws(static () => Base64UrlEncoding.Encode(null, 0, 0)); + Assert.Throws(static () => Base64UrlEncoding.Decode("abc", 0, 0, null)); + Assert.Throws(static () => Base64UrlEncoding.Decode("abc", 0, 0, null, null)); + Assert.Throws(static () => Base64UrlEncoding.Decode(null, 0, 0, null, null, null, null)); + } } }