Skip to content

Commit

Permalink
Update JsonWebToken for child classes.
Browse files Browse the repository at this point in the history
Add another claims dictionary to JsonWebToken.Payload to hold string claims as UTF8 ReadOnlyMemory<byte>.

Using and reading from the original token as a Span and saving claim value indices instead of creating Memor<byte> for each claim.

Add delegate for reading properties instead of an overload.
  • Loading branch information
pmaytak authored and HP712 committed Sep 9, 2024
1 parent 5ba278b commit 8b61ba8
Show file tree
Hide file tree
Showing 10 changed files with 483 additions and 153 deletions.
103 changes: 100 additions & 3 deletions src/Microsoft.IdentityModel.JsonWebTokens/Json/JsonClaimSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,37 @@ internal class JsonClaimSet

internal object _claimsLock = new();
internal readonly Dictionary<string, object> _jsonClaims;

#if NET8_0_OR_GREATER
internal readonly Memory<byte> _tokenAsMemory;
#endif

private List<Claim> _claims;

internal JsonClaimSet()
{
_jsonClaims = new Dictionary<string, object>();
_jsonClaims = [];

#if NET8_0_OR_GREATER
_tokenAsMemory = Memory<byte>.Empty;
#endif
}

internal JsonClaimSet(Dictionary<string, object> jsonClaims)
{
_jsonClaims = jsonClaims;
}

#if NET8_0_OR_GREATER
internal JsonClaimSet(
Dictionary<string, object> jsonClaims,
Memory<byte> tokenAsMemory)
{
_jsonClaims = jsonClaims;
_tokenAsMemory = tokenAsMemory;
}
#endif

internal List<Claim> Claims(string issuer)
{
if (_claims == null)
Expand All @@ -49,8 +68,20 @@ internal List<Claim> CreateClaims(string issuer)
{
var claims = new List<Claim>(_jsonClaims.Count);
foreach (KeyValuePair<string, object> kvp in _jsonClaims)
{
CreateClaimFromObject(claims, kvp.Key, kvp.Value, issuer);

#if NET8_0_OR_GREATER
if (kvp.Value is ValuePosition position)
{
if (position.IsEscaped)
EscapeStringBytesInPlace(position);

string value = System.Text.Encoding.UTF8.GetString(_tokenAsMemory.Slice(position.StartIndex, position.Length).Span);
claims.Add(new Claim(kvp.Key, value, ClaimValueTypes.String, issuer, issuer));
}
#endif
}
return claims;
}

Expand Down Expand Up @@ -167,12 +198,56 @@ internal string GetStringValue(string key)
if (obj == null)
return null;

#if NET8_0_OR_GREATER
if (obj is ValuePosition position)
{
if (position.IsEscaped)
EscapeStringBytesInPlace(position);

return System.Text.Encoding.UTF8.GetString(_tokenAsMemory.Slice(position.StartIndex, position.Length).Span);
}
#endif
return obj.ToString();
}

return string.Empty;
}

#if NET8_0_OR_GREATER
// Similar to GetStringValue but returns the bytes directly.
internal ReadOnlySpan<byte> GetStringBytesValue(string key)
{
if (_jsonClaims.TryGetValue(key, out object obj))
{
if (obj == null)
return null;

if (obj is ValuePosition position)
{
if (position.IsEscaped)
EscapeStringBytesInPlace(position);

return _tokenAsMemory.Slice(position.StartIndex, position.Length).Span;
}
}

return [];
}

/// <summary>
/// Unescapes the bytes of a string claim value in-place in the token bytes Memory instance.
/// After escaping, updates the length of the claim value to reflect the unescaped bytes.
/// </summary>
/// <remarks>The start position and length provided to the Utf8JsonReader has to be adjusted to include double quotes.</remarks>
/// <param name="position">Position of the claim value.</param>
private void EscapeStringBytesInPlace(ValuePosition position)
{
var reader = new Utf8JsonReader(_tokenAsMemory.Span.Slice(position.StartIndex - 1, position.Length + 2));
reader.Read();
position.Length = reader.CopyString(_tokenAsMemory.Span.Slice(position.StartIndex, position.Length));
position.IsEscaped = false;
}
#endif

internal DateTime GetDateTime(string key)
{
long l = GetValue<long>(key, false, out bool found);
Expand Down Expand Up @@ -235,8 +310,19 @@ internal T GetValue<T>(string key, bool throwEx, out bool found)
if (list.Count == 1)
return (T)((object)(list[0]));
}
#if NET8_0_OR_GREATER
else if (obj is ValuePosition position)
{
if (position.IsEscaped)
EscapeStringBytesInPlace(position);

return (T)(object)System.Text.Encoding.UTF8.GetString(_tokenAsMemory.Slice(position.StartIndex, position.Length).Span);
}
#endif
else
{
return (T)((object)obj.ToString());
}
}
else if (typeof(T) == typeof(bool))
{
Expand Down Expand Up @@ -425,13 +511,24 @@ internal bool TryGetClaim(string key, string issuer, out Claim claim)
/// <returns><see langword="true"/> if the key was found; otherwise, <see langword="false"/>.</returns>
internal bool TryGetValue<T>(string key, out T value)
{
#if NET8_0_OR_GREATER
if (typeof(T) == typeof(string))
{
var span = GetStringBytesValue(key);
if (!span.IsEmpty)
{
value = (T)(object)System.Text.Encoding.UTF8.GetString(span);
return true;
}
}
#endif
value = GetValue<T>(key, false, out bool found);
return found;
}

internal bool HasClaim(string claimName)
{
return _jsonClaims.TryGetValue(claimName, out _);
return _jsonClaims.ContainsKey(claimName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ public partial class JsonWebToken
{
internal JsonClaimSet CreateHeaderClaimSet(byte[] bytes)
{
return CreateHeaderClaimSet(bytes.AsSpan());
return CreateHeaderClaimSet(bytes.AsMemory());
}

internal JsonClaimSet CreateHeaderClaimSet(byte[] bytes, int length)
{
return CreateHeaderClaimSet(bytes.AsSpan(0, length));
return CreateHeaderClaimSet(bytes.AsMemory(0, length));
}

internal JsonClaimSet CreateHeaderClaimSet(ReadOnlySpan<byte> byteSpan)
internal JsonClaimSet CreateHeaderClaimSet(Memory<byte> tokenHeaderAsMemory)
{
Utf8JsonReader reader = new(byteSpan);
Utf8JsonReader reader = new(tokenHeaderAsMemory.Span);
if (!JsonSerializerPrimitives.IsReaderAtTokenType(ref reader, JsonTokenType.StartObject, true))
throw LogHelper.LogExceptionMessage(
new JsonException(
Expand All @@ -36,46 +36,13 @@ internal JsonClaimSet CreateHeaderClaimSet(ReadOnlySpan<byte> byteSpan)
LogHelper.MarkAsNonPII(reader.CurrentDepth),
LogHelper.MarkAsNonPII(reader.BytesConsumed))));

Dictionary<string, object> claims = new();
Dictionary<string, object> claims = [];
while (true)
{
if (reader.TokenType == JsonTokenType.PropertyName)
{
if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Alg))
{
_alg = JsonSerializerPrimitives.ReadString(ref reader, JwtHeaderParameterNames.Alg, ClassName, true);
claims[JwtHeaderParameterNames.Alg] = _alg;
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Cty))
{
_cty = JsonSerializerPrimitives.ReadString(ref reader, JwtHeaderParameterNames.Cty, ClassName, true);
claims[JwtHeaderParameterNames.Cty] = _cty;
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Kid))
{
_kid = JsonSerializerPrimitives.ReadString(ref reader, JwtHeaderParameterNames.Kid, ClassName, true);
claims[JwtHeaderParameterNames.Kid] = _kid;
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Typ))
{
_typ = JsonSerializerPrimitives.ReadString(ref reader, JwtHeaderParameterNames.Typ, ClassName, true);
claims[JwtHeaderParameterNames.Typ] = _typ;
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.X5t))
{
_x5t = JsonSerializerPrimitives.ReadString(ref reader, JwtHeaderParameterNames.X5t, ClassName, true);
claims[JwtHeaderParameterNames.X5t] = _x5t;
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Zip))
{
_zip = JsonSerializerPrimitives.ReadString(ref reader, JwtHeaderParameterNames.Zip, ClassName, true);
claims[JwtHeaderParameterNames.Zip] = _zip;
}
else
{
string propertyName = reader.GetString();
claims[propertyName] = JsonSerializerPrimitives.ReadPropertyValueAsObject(ref reader, propertyName, JsonClaimSet.ClassName, true);
}
string claimName = reader.GetString();
claims[claimName] = ReadTokenHeaderValueDelegate(ref reader, claimName);
}
// We read a JsonTokenType.StartObject above, exiting and positioning reader at next token.
else if (JsonSerializerPrimitives.IsReaderAtTokenType(ref reader, JsonTokenType.EndObject, false))
Expand All @@ -84,7 +51,74 @@ internal JsonClaimSet CreateHeaderClaimSet(ReadOnlySpan<byte> byteSpan)
break;
};

#if NET8_0_OR_GREATER
return new JsonClaimSet(claims, tokenHeaderAsMemory);
#else
return new JsonClaimSet(claims);
#endif
}

/// <summary>
/// Reads and saves the value of the header claim from the reader.
/// </summary>
/// <param name="reader">The reader over the JWT.</param>
/// <param name="claimName">The claim at the current position of the reader.</param>
/// <returns>A claim that was read.</returns>
public static object ReadTokenHeaderValue(ref Utf8JsonReader reader, string claimName)
{
#if NET8_0_OR_GREATER
if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Alg))
{
return JsonSerializerPrimitives.ReadStringBytesLocation(ref reader, JwtHeaderParameterNames.Alg, ClassName, true);
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Cty))
{
return JsonSerializerPrimitives.ReadStringBytesLocation(ref reader, JwtHeaderParameterNames.Cty, ClassName, true);
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Kid))
{
return JsonSerializerPrimitives.ReadStringBytesLocation(ref reader, JwtHeaderParameterNames.Kid, ClassName, true);
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Typ))
{
return JsonSerializerPrimitives.ReadStringBytesLocation(ref reader, JwtHeaderParameterNames.Typ, ClassName, true);
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.X5t))
{
return JsonSerializerPrimitives.ReadStringBytesLocation(ref reader, JwtHeaderParameterNames.X5t, ClassName, true);
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Zip))
{
return JsonSerializerPrimitives.ReadStringBytesLocation(ref reader, JwtHeaderParameterNames.Zip, ClassName, true);
}
#else
if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Alg))
{
return JsonSerializerPrimitives.ReadString(ref reader, JwtHeaderParameterNames.Alg, ClassName, true);
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Cty))
{
return JsonSerializerPrimitives.ReadString(ref reader, JwtHeaderParameterNames.Cty, ClassName, true);
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Kid))
{
return JsonSerializerPrimitives.ReadString(ref reader, JwtHeaderParameterNames.Kid, ClassName, true);
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Typ))
{
return JsonSerializerPrimitives.ReadString(ref reader, JwtHeaderParameterNames.Typ, ClassName, true);
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.X5t))
{
return JsonSerializerPrimitives.ReadString(ref reader, JwtHeaderParameterNames.X5t, ClassName, true);
}
else if (reader.ValueTextEquals(JwtHeaderUtf8Bytes.Zip))
{
return JsonSerializerPrimitives.ReadString(ref reader, JwtHeaderParameterNames.Zip, ClassName, true);
}
#endif

return JsonSerializerPrimitives.ReadPropertyValueAsObject(ref reader, claimName, JsonClaimSet.ClassName, true);
}
}
}
Loading

0 comments on commit 8b61ba8

Please sign in to comment.