diff --git a/benchmark/Microsoft.IdentityModel.Benchmarks/Microsoft.IdentityModel.Benchmarks.csproj b/benchmark/Microsoft.IdentityModel.Benchmarks/Microsoft.IdentityModel.Benchmarks.csproj index b82cae25f0..f30681b251 100644 --- a/benchmark/Microsoft.IdentityModel.Benchmarks/Microsoft.IdentityModel.Benchmarks.csproj +++ b/benchmark/Microsoft.IdentityModel.Benchmarks/Microsoft.IdentityModel.Benchmarks.csproj @@ -48,9 +48,4 @@ - - - true - - diff --git a/build/common.props b/build/common.props index 41c636cc18..73d9e7c851 100644 --- a/build/common.props +++ b/build/common.props @@ -31,11 +31,6 @@ true - - - true - - false 8.0.0 diff --git a/build/commonTest.props b/build/commonTest.props index 10172f3297..91df851ece 100644 --- a/build/commonTest.props +++ b/build/commonTest.props @@ -51,11 +51,6 @@ - - - true - - true diff --git a/build/dependencies.props b/build/dependencies.props index 9cb0568fb6..63ea05785f 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -12,7 +12,6 @@ 4.5.5 4.5.0 8.0.5 - 9.0.0 diff --git a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs index eb8c98e036..8ebea86b7e 100644 --- a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs +++ b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs @@ -3,7 +3,9 @@ using System; using System.Buffers; +#if NET9_0_OR_GREATER using System.Buffers.Text; +#endif using System.Text; using Microsoft.IdentityModel.Logging; @@ -21,6 +23,8 @@ public static class Base64UrlEncoder private const char Base64PadCharacter = '='; private const char Base64Character62 = '+'; private const char Base64Character63 = '/'; + private const char Base64UrlCharacter62 = '-'; + private const char Base64UrlCharacter63 = '_'; /// /// Performs base64url encoding, which differs from regular base64 encoding as follows: @@ -98,9 +102,17 @@ public static string Encode(byte[] inArray, int offset, int length) LogHelper.MarkAsNonPII(inArray.Length)))); #pragma warning restore CA2208 // Instantiate argument exceptions correctly +#if NET9_0_OR_GREATER return Base64Url.EncodeToString(inArray.AsSpan().Slice(offset, length)); +#else + char[] destination = new char[(inArray.Length + 2) / 3 * 4]; + int j = Encode(inArray.AsSpan().Slice(offset, length), destination.AsSpan()); + + return new string(destination, 0, j); +#endif } +#if NET9_0_OR_GREATER /// /// Populates a with the base64url encoded representation of a of bytes. /// @@ -108,7 +120,68 @@ public static string Encode(byte[] inArray, int offset, int length) /// The span of characters to write the encoded output. /// The number of characters written to the output span. public static int Encode(ReadOnlySpan inArray, Span output) => Base64Url.EncodeToChars(inArray, output); +#else + /// + /// Populates a with the base64url encoded representation of a of bytes. + /// + /// A read-only span of bytes to encode. + /// The span of characters to write the encoded output. + /// The number of characters written to the output span. + public static int Encode(ReadOnlySpan inArray, Span output) + { + int lengthmod3 = inArray.Length % 3; + int limit = (inArray.Length - lengthmod3); + ReadOnlySpan table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"u8; + + int i, j = 0; + + // takes 3 bytes from inArray and insert 4 bytes into output + for (i = 0; i < limit; i += 3) + { + byte d0 = inArray[i]; + byte d1 = inArray[i + 1]; + byte d2 = inArray[i + 2]; + + output[j + 0] = (char)table[d0 >> 2]; + output[j + 1] = (char)table[((d0 & 0x03) << 4) | (d1 >> 4)]; + output[j + 2] = (char)table[((d1 & 0x0f) << 2) | (d2 >> 6)]; + output[j + 3] = (char)table[d2 & 0x3f]; + j += 4; + } + + //Where we left off before + i = limit; + switch (lengthmod3) + { + case 2: + { + byte d0 = inArray[i]; + byte d1 = inArray[i + 1]; + + output[j + 0] = (char)table[d0 >> 2]; + output[j + 1] = (char)table[((d0 & 0x03) << 4) | (d1 >> 4)]; + output[j + 2] = (char)table[(d1 & 0x0f) << 2]; + j += 3; + } + break; + case 1: + { + byte d0 = inArray[i]; + + output[j + 0] = (char)table[d0 >> 2]; + output[j + 1] = (char)table[(d0 & 0x03) << 4]; + j += 2; + } + break; + + //default or case 0: no further operations are needed. + } + + return j; + } + +#endif /// /// Converts the specified base64url encoded string to UTF-8 bytes. /// @@ -123,6 +196,8 @@ public static byte[] DecodeBytes(string str) #if NETCOREAPP [SkipLocalsInit] #endif + +#if NET9_0_OR_GREATER internal static byte[] Decode(ReadOnlySpan strSpan) { int upperBound = Base64Url.GetMaxDecodedLength(strSpan.Length); @@ -144,33 +219,37 @@ internal static byte[] Decode(ReadOnlySpan strSpan) ArrayPool.Shared.Return(rented, true); } } - -#if !NET8_0_OR_GREATER - private static bool IsOnlyValidBase64Chars(ReadOnlySpan strSpan) +#else + internal static byte[] Decode(ReadOnlySpan strSpan) { - foreach (char c in strSpan) - if (!char.IsDigit(c) && !char.IsLetter(c) && c != Base64Character62 && c != Base64Character63 && c != Base64PadCharacter) - return false; + int mod = strSpan.Length % 4; + if (mod == 1) + throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString()))); - return true; + bool needReplace = strSpan.IndexOfAny(Base64UrlCharacter62, Base64UrlCharacter63) >= 0; + int decodedLength = strSpan.Length + (4 - mod) % 4; +#if NET6_0_OR_GREATER + Span output = new byte[decodedLength]; + int length = Decode(strSpan, output, needReplace, decodedLength); + return output.Slice(0, length).ToArray(); +#else + return UnsafeDecode(strSpan, needReplace, decodedLength); +#endif } - #endif + #if NETCOREAPP [SkipLocalsInit] #endif + +#if NET9_0_OR_GREATER internal static int Decode(ReadOnlySpan strSpan, Span output) { OperationStatus status = Base64Url.DecodeFromChars(strSpan, output, out _, out int bytesWritten); if (status == OperationStatus.Done) return bytesWritten; - if (status == OperationStatus.InvalidData && -#if NET8_0_OR_GREATER - !Base64.IsValid(strSpan)) -#else - !IsOnlyValidBase64Chars(strSpan)) -#endif + if (status == OperationStatus.InvalidData && !Base64.IsValid(strSpan)) throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString()))); int mod = strSpan.Length % 4; @@ -180,8 +259,24 @@ internal static int Decode(ReadOnlySpan strSpan, Span output) return Decode(strSpan, output, decodedLength); } +#else + internal static void Decode(ReadOnlySpan strSpan, Span output) + { + int mod = strSpan.Length % 4; + if (mod == 1) + throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString()))); + bool needReplace = strSpan.IndexOfAny(Base64UrlCharacter62, Base64UrlCharacter63) >= 0; + int decodedLength = strSpan.Length + (4 - mod) % 4; +#if NET6_0_OR_GREATER + Decode(strSpan, output, needReplace, decodedLength); +#else + Decode(strSpan, output, needReplace, decodedLength); +#endif + } -#if NETCOREAPP +#endif + +#if NET9_0_OR_GREATER [SkipLocalsInit] private static int Decode(ReadOnlySpan strSpan, Span output, int decodedLength) { @@ -254,33 +349,144 @@ private static ReadOnlySpan HandlePadding(ReadOnlySpan source, Span< return charsSpan; } -#else - private static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan, int decodedLength) +#elif NET6_0_OR_GREATER + [SkipLocalsInit] + private static int Decode(ReadOnlySpan strSpan, Span output, bool needReplace, int decodedLength) { - if (decodedLength == strSpan.Length) + // If the incoming chars don't contain any of the base64url characters that need to be replaced, + // and if the incoming chars are of the exact right length, then we'll be able to just pass the + // incoming chars directly to DecodeFromUtf8InPlace. Otherwise, rent an array, copy all the + // data into it, and do whatever fixups are necessary on that copy, then pass that copy into + // DecodeFromUtf8InPlace. + + const int StackAllocThreshold = 512; + char[] arrayPoolChars = null; + scoped Span charsSpan = default; + scoped ReadOnlySpan source = strSpan; + + if (needReplace || decodedLength != source.Length) + { + charsSpan = decodedLength <= StackAllocThreshold ? + stackalloc char[StackAllocThreshold] : + arrayPoolChars = ArrayPool.Shared.Rent(decodedLength); + charsSpan = charsSpan.Slice(0, decodedLength); + + source = HandlePaddingAndReplace(source, charsSpan, needReplace); + } + + byte[] arrayPoolBytes = null; + Span bytesSpan = decodedLength <= StackAllocThreshold ? + stackalloc byte[StackAllocThreshold] : + arrayPoolBytes = ArrayPool.Shared.Rent(decodedLength); + + int length = Encoding.UTF8.GetBytes(source, bytesSpan); + Span utf8Span = bytesSpan.Slice(0, length); + try { - return Convert.FromBase64CharArray(strSpan.ToArray(), 0, strSpan.Length); + OperationStatus status = System.Buffers.Text.Base64.DecodeFromUtf8InPlace(utf8Span, out int bytesWritten); + if (status != OperationStatus.Done) + throw LogHelper.LogExceptionMessage(new FormatException(LogHelper.FormatInvariant(LogMessages.IDX10400, strSpan.ToString()))); + + utf8Span.Slice(0, bytesWritten).CopyTo(output); + + return bytesWritten; } + finally + { + if (arrayPoolBytes is not null) + { + bytesSpan.Clear(); + ArrayPool.Shared.Return(arrayPoolBytes); + } - string decodedString = new(char.MinValue, decodedLength); - fixed (char* src = strSpan) - fixed (char* dest = decodedString) + if (arrayPoolChars is not null) + { + charsSpan.Clear(); + ArrayPool.Shared.Return(arrayPoolChars); + } + } + } + + private static ReadOnlySpan HandlePaddingAndReplace(ReadOnlySpan source, Span charsSpan, bool needReplace) + { + source.CopyTo(charsSpan); + if (source.Length < charsSpan.Length) { - Buffer.MemoryCopy(src, dest, strSpan.Length * 2, strSpan.Length * 2); + charsSpan[source.Length] = Base64PadCharacter; + if (source.Length + 1 < charsSpan.Length) + { + charsSpan[source.Length + 1] = Base64PadCharacter; + } + } - dest[strSpan.Length] = Base64PadCharacter; - if (strSpan.Length + 2 == decodedLength) - dest[strSpan.Length + 1] = Base64PadCharacter; + if (needReplace) + { + Span remaining = charsSpan; + int pos; + while ((pos = remaining.IndexOfAny(Base64UrlCharacter62, Base64UrlCharacter63)) >= 0) + { + remaining[pos] = (remaining[pos] == Base64UrlCharacter62) ? Base64Character62 : Base64Character63; + remaining = remaining.Slice(pos + 1); + } } - return Convert.FromBase64String(decodedString); + return charsSpan; } - private static int Decode(ReadOnlySpan strSpan, Span output, int decodedLength) +#else + private static unsafe byte[] UnsafeDecode(ReadOnlySpan strSpan, bool needReplace, int decodedLength) + { + if (needReplace) + { + string decodedString = new(char.MinValue, decodedLength); + fixed (char* dest = decodedString) + { + int i = 0; + for (; i < strSpan.Length; i++) + { + if (strSpan[i] == Base64UrlCharacter62) + dest[i] = Base64Character62; + else if (strSpan[i] == Base64UrlCharacter63) + dest[i] = Base64Character63; + else + dest[i] = strSpan[i]; + } + + for (; i < decodedLength; i++) + dest[i] = Base64PadCharacter; + } + + return Convert.FromBase64String(decodedString); + } + else + { + if (decodedLength == strSpan.Length) + { + return Convert.FromBase64CharArray(strSpan.ToArray(), 0, strSpan.Length); + } + else + { + string decodedString = new(char.MinValue, decodedLength); + fixed (char* src = strSpan) + fixed (char* dest = decodedString) + { + Buffer.MemoryCopy(src, dest, strSpan.Length * 2, strSpan.Length * 2); + + dest[strSpan.Length] = Base64PadCharacter; + if (strSpan.Length + 2 == decodedLength) + dest[strSpan.Length + 1] = Base64PadCharacter; + } + + return Convert.FromBase64String(decodedString); + } + } + } + + private static void Decode(ReadOnlySpan strSpan, Span output, bool needReplace, int decodedLength) { - byte[] result = UnsafeDecode(strSpan, decodedLength); + byte[] result = UnsafeDecode(strSpan, needReplace, decodedLength); result.CopyTo(output); - return result.Length; + } #endif diff --git a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoding.cs b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoding.cs index f2d6cb4f55..c26d432917 100644 --- a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoding.cs +++ b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoding.cs @@ -3,13 +3,41 @@ using System; using System.Buffers; +#if NET9_0_OR_GREATER using System.Buffers.Text; +#endif using Microsoft.IdentityModel.Logging; namespace Microsoft.IdentityModel.Tokens { + /// + /// For Non-Net9.0 Targets: Base64 encode/decode implementation for as per https://tools.ietf.org/html/rfc4648#section-5. + /// For Net9.0 Targets: Uses System.Buffers.Text.Base64Url to perform the encoding/decoding. + /// Uses ArrayPool[T] to minimize memory usage. + /// internal static class Base64UrlEncoding { + private const uint IntA = 'A'; + private const uint IntZ = 'Z'; + private const uint Inta = 'a'; + private const uint Intz = 'z'; + private const uint Int0 = '0'; + private const uint Int9 = '9'; + private const uint IntEq = '='; + private const uint IntPlus = '+'; + private const uint IntMinus = '-'; + private const uint IntSlash = '/'; + private const uint IntUnderscore = '_'; + + private static readonly char[] Base64Table = + { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', + 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', + 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', + 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', + '8', '9', '-', '_', + }; + /// /// Decodes a base64url encoded string into a byte array. /// @@ -155,6 +183,8 @@ public static T Decode( } } + +#if NET9_0_OR_GREATER /// /// Decodes a Base64Url encoded substring of a string into a byte array. /// @@ -164,6 +194,89 @@ public static T Decode( /// The byte array to place the decoded results into. internal static void Decode(ReadOnlySpan input, int offset, int length, byte[] output) => Base64Url.DecodeFromChars(input.Slice(offset, length), output); +#else + /// + /// Changes from Base64UrlEncoder implementation: + /// 1. Padding is optional. + /// 2. '+' and '-' are treated the same. + /// 3. '/' and '_' are treated the same. + /// + internal static void Decode(ReadOnlySpan input, int offset, int length, byte[] output) + { + int outputpos = 0; + uint curblock = 0x000000FFu; + for (int i = offset; i < (offset + length); i++) + { + uint cur = input[i]; + if (cur >= IntA && cur <= IntZ) + { + cur -= IntA; + } + else if (cur >= Inta && cur <= Intz) + { + cur = (cur - Inta) + 26u; + } + else if (cur >= Int0 && cur <= Int9) + { + cur = (cur - Int0) + 52u; + } + else if (cur == IntPlus || cur == IntMinus) + { + cur = 62u; + } + else if (cur == IntSlash || cur == IntUnderscore) + { + cur = 63u; + } + else if (cur == IntEq) + { + continue; + } + else + { + throw LogHelper.LogExceptionMessage(new ArgumentOutOfRangeException( + LogHelper.FormatInvariant( + LogMessages.IDX10820, + LogHelper.MarkAsNonPII(cur), + input.ToString()))); + } + + curblock = (curblock << 6) | cur; + + // check if 4 characters have been read, based on number of shifts. + if ((0xFF000000u & curblock) == 0xFF000000u) + { + output[outputpos++] = (byte)(curblock >> 16); + output[outputpos++] = (byte)(curblock >> 8); + output[outputpos++] = (byte)curblock; + curblock = 0x000000FFu; + } + } + + // Handle spill over characters. This accounts for case where padding character is not present. + if (curblock != 0x000000FFu) + { + if ((0x03FC0000u & curblock) == 0x03FC0000u) + { + // shifted 3 times, 1 padding character, 2 output characters + curblock <<= 6; + output[outputpos++] = (byte)(curblock >> 16); + output[outputpos++] = (byte)(curblock >> 8); + } + else if ((0x000FF000u & curblock) == 0x000FF000u) + { + // shifted 2 times, 2 padding character, 1 output character + curblock <<= 12; + output[outputpos++] = (byte)(curblock >> 16); + } + else + { + throw LogHelper.LogExceptionMessage(new ArgumentException( + LogHelper.FormatInvariant(LogMessages.IDX10821, input.ToString()))); + } + } + } +#endif /// /// Encodes a byte array into a base64url encoded string. @@ -220,7 +333,19 @@ public static string Encode(byte[] input, int offset, int length) LogHelper.MarkAsNonPII(input.Length)))); #pragma warning restore CA2208 // Instantiate argument exceptions correctly +#if NET9_0_OR_GREATER return Base64Url.EncodeToString(input.AsSpan().Slice(offset, length)); +#else + int outputsize = length % 3; + if (outputsize > 0) + outputsize++; + + outputsize += (length / 3) * 4; + + char[] output = new char[outputsize]; + WriteEncodedOutput(input, offset, length, output); + return new string(output); +#endif } /// @@ -284,5 +409,42 @@ internal static int ValidateAndGetOutputSize(ReadOnlySpan strSpan, int off outputSize += (effectiveLength / 4) * 3; return outputSize; } + +#if !NET9_0_OR_GREATER + private static void WriteEncodedOutput(byte[] inputBytes, int offset, int length, Span output) + { + uint curBlock = 0x000000FFu; + int outputPointer = 0; + + for (int i = offset; i < offset + length; i++) + { + curBlock = (curBlock << 8) | inputBytes[i]; + + if ((curBlock & 0xFF000000u) == 0xFF000000u) + { + output[outputPointer++] = Base64Table[(curBlock & 0x00FC0000u) >> 18]; + output[outputPointer++] = Base64Table[(curBlock & 0x00030000u | curBlock & 0x0000F000u) >> 12]; + output[outputPointer++] = Base64Table[(curBlock & 0x00000F00u | curBlock & 0x000000C0u) >> 6]; + output[outputPointer++] = Base64Table[curBlock & 0x0000003Fu]; + + curBlock = 0x000000FFu; + } + } + + if ((curBlock & 0x00FF0000u) == 0x00FF0000u) + { + // 2 shifts, 3 output characters. + output[outputPointer++] = Base64Table[(curBlock & 0x0000FC00u) >> 10]; + output[outputPointer++] = Base64Table[(curBlock & 0x000003F0u) >> 4]; + output[outputPointer++] = Base64Table[(curBlock & 0x0000000Fu) << 2]; + } + else if ((curBlock & 0x0000FF00u) == 0x0000FF00u) + { + // 1 shift, 2 output characters. + output[outputPointer++] = Base64Table[(curBlock & 0x000000FCu) >> 2]; + output[outputPointer++] = Base64Table[(curBlock & 0x00000003u) << 4]; + } + } +#endif } } diff --git a/src/Microsoft.IdentityModel.Tokens/LogMessages.cs b/src/Microsoft.IdentityModel.Tokens/LogMessages.cs index 163d8e8d73..00cfbaeb4a 100644 --- a/src/Microsoft.IdentityModel.Tokens/LogMessages.cs +++ b/src/Microsoft.IdentityModel.Tokens/LogMessages.cs @@ -274,6 +274,11 @@ internal static class LogMessages public const string IDX10902 = "IDX10902: Exception caught while removing expired items: '{0}', Exception: '{1}'"; public const string IDX10906 = "IDX10906: Exception caught while compacting items: '{0}', Exception: '{1}'"; + + // Base64UrlEncoding + public const string IDX10820 = "IDX10820: Invalid character found in Base64UrlEncoding. Character: '{0}', Encoding: '{1}'."; + public const string IDX10821 = "IDX10821: Incorrect padding detected in Base64UrlEncoding. Encoding: '{0}'."; + // Crypto Errors public const string IDX11000 = "IDX11000: Cannot create EcdhKeyExchangeProvider. '{0}'\'s Curve '{1}' does not match with '{2}'\'s curve '{3}'."; public const string IDX11001 = "IDX11001: Cannot generate KDF. '{0}':'{1}' and '{2}':'{3}' must be different."; diff --git a/src/Microsoft.IdentityModel.Tokens/Microsoft.IdentityModel.Tokens.csproj b/src/Microsoft.IdentityModel.Tokens/Microsoft.IdentityModel.Tokens.csproj index 25eb2bbc34..138feaaa92 100644 --- a/src/Microsoft.IdentityModel.Tokens/Microsoft.IdentityModel.Tokens.csproj +++ b/src/Microsoft.IdentityModel.Tokens/Microsoft.IdentityModel.Tokens.csproj @@ -65,10 +65,6 @@ - - - -