diff --git a/csharp/Svix.Tests/WebhookTest.cs b/csharp/Svix.Tests/WebhookTest.cs index 707bf6e0d..c7889db17 100644 --- a/csharp/Svix.Tests/WebhookTest.cs +++ b/csharp/Svix.Tests/WebhookTest.cs @@ -207,5 +207,21 @@ public void VerifyWebhookSignWorks() var signature = wh.Sign(msgId, timestamp, payload); Assert.Equal(signature, expected); } + + [Theory] + [InlineData(1024 * 1024, "v1,txpEUxqWZJ5nteTnymUVa+7C4NHpBeXJ6CsBAW0c3/A=")] + [InlineData(256 * 1024, "v1,Mw4Pe2WgApuT7NSqnSq0PQPV9gbLdggCl9B865x5Xh0=")] + [InlineData(256 * 1024 - 1, "v1,NTItnqyWMoFhk2Bbe5V/BMdcFRWgcSQaTgqLHX7FC7c=")] + public void VerifyWebhookSignWithLargePayloadWorks(int payloadSize, string expected) + { + var key = "whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw"; + var msgId = "msg_p5jXN8AQM9LWM0D4loKWxJek"; + var timestamp = DateTimeOffset.FromUnixTimeSeconds(1614265330); + var payload = new string('a', payloadSize); + + var wh = new Webhook(key); + var signature = wh.Sign(msgId, timestamp, payload); + Assert.Equal(signature, expected); + } } } diff --git a/csharp/Svix/Utils.cs b/csharp/Svix/Utils.cs index 1a9584d84..207ab4199 100644 --- a/csharp/Svix/Utils.cs +++ b/csharp/Svix/Utils.cs @@ -5,22 +5,11 @@ namespace Svix { internal static class Utils { - // Borrowed from Stripe-dotnet https://github.com/stripe/stripe-dotnet/blob/7b62c461d7c0cf2c9e06dce5e564b374a9d232e0/src/Stripe.net/Infrastructure/StringUtils.cs#L30 // basically identical to SecureCompare from Rails::ActiveSupport used in our ruby lib [MethodImpl(MethodImplOptions.NoOptimization)] - public static bool SecureCompare(string a, string b) + public static bool SecureCompare(ReadOnlySpan a, ReadOnlySpan b) { - if (a == null) - { - throw new ArgumentNullException(nameof(a)); - } - - if (b == null) - { - throw new ArgumentNullException(nameof(b)); - } - if (a.Length != b.Length) { return false; diff --git a/csharp/Svix/Webhook.cs b/csharp/Svix/Webhook.cs index 2fcb41dc3..0f35e634e 100644 --- a/csharp/Svix/Webhook.cs +++ b/csharp/Svix/Webhook.cs @@ -1,5 +1,7 @@ using Svix.Exceptions; using System; +using System.Buffers; +using System.Buffers.Text; using System.Net; using System.Security.Cryptography; using System.Text; @@ -17,14 +19,19 @@ public sealed class Webhook internal const string UNBRANDED_SIGNATURE_HEADER_KEY = "webhook-signature"; internal const string UNBRANDED_TIMESTAMP_HEADER_KEY = "webhook-timestamp"; + private const int SIGNATURE_LENGTH_BYTES = HMACSHA256.HashSizeInBytes; + private const int SIGNATURE_LENGTH_BASE64 = 48; + private const int SIGNATURE_LENGTH_STRING = 56; private const int TOLERANCE_IN_SECONDS = 60 * 5; - private static string prefix = "whsec_"; - private byte[] key; + private const int MAX_STACKALLOC = 1024 * 256; + private const string PREFIX = "whsec_"; + + private readonly byte[] key; public Webhook(string key) { - if (key.StartsWith(prefix)) + if (key.StartsWith(PREFIX)) { - key = key.Substring(prefix.Length); + key = key.Substring(PREFIX.Length); } this.key = Convert.FromBase64String(key); @@ -35,63 +42,84 @@ public Webhook(byte[] key) this.key = key; } - public void Verify(string payload, WebHeaderCollection headers) + public void Verify(ReadOnlySpan payload, WebHeaderCollection headers) { - ArgumentNullException.ThrowIfNull(headers); + if (payload == null) + { + throw new ArgumentNullException(nameof(payload)); + } + if (headers == null) + { + throw new ArgumentNullException(nameof(headers)); + } Verify(payload, headers.Get); } - public void Verify(string payload, Func headersProvider) + public void Verify(ReadOnlySpan payload, Func headersProvider) { - ArgumentNullException.ThrowIfNull(payload); - ArgumentNullException.ThrowIfNull(headersProvider); + if (payload == null) + { + throw new ArgumentNullException(nameof(payload)); + } + if (headersProvider == null) + { + throw new ArgumentNullException(nameof(headersProvider)); + } - string msgId = headersProvider(SVIX_ID_HEADER_KEY); - string msgSignature = headersProvider(SVIX_SIGNATURE_HEADER_KEY); - string msgTimestamp = headersProvider(SVIX_TIMESTAMP_HEADER_KEY); - - if (String.IsNullOrEmpty(msgId) || String.IsNullOrEmpty(msgSignature) || String.IsNullOrEmpty(msgTimestamp)) + ReadOnlySpan msgId = headersProvider(SVIX_ID_HEADER_KEY); + ReadOnlySpan msgTimestamp = headersProvider(SVIX_TIMESTAMP_HEADER_KEY); + ReadOnlySpan msgSignature = headersProvider(SVIX_SIGNATURE_HEADER_KEY); + + if (msgId.IsEmpty || msgSignature.IsEmpty || msgTimestamp.IsEmpty) { msgId = headersProvider(UNBRANDED_ID_HEADER_KEY); msgSignature = headersProvider(UNBRANDED_SIGNATURE_HEADER_KEY); msgTimestamp = headersProvider(UNBRANDED_TIMESTAMP_HEADER_KEY); - if (String.IsNullOrEmpty(msgId) || String.IsNullOrEmpty(msgSignature) || String.IsNullOrEmpty(msgTimestamp)) + if (msgId.IsEmpty || msgSignature.IsEmpty || msgTimestamp.IsEmpty) { throw new WebhookVerificationException("Missing Required Headers"); } } - var timestamp = Webhook.VerifyTimestamp(msgTimestamp); + Webhook.VerifyTimestamp(msgTimestamp); - var signature = this.Sign(msgId, timestamp, payload); - var expectedSignature = signature.Split(',')[1]; + Span expectedSignature = stackalloc char[SIGNATURE_LENGTH_STRING]; + CalculateSignature(msgId, msgTimestamp, payload, expectedSignature, out var charsWritten); + expectedSignature = expectedSignature.Slice(0, charsWritten); - var passedSignatures = msgSignature.Split(' '); - foreach (string versionedSignature in passedSignatures) + var signaturePtr = msgSignature; + var spaceIndex = signaturePtr.IndexOf(' '); + do { - var parts = versionedSignature.Split(','); - if (parts.Length < 2) + var versionedSignature = spaceIndex < 0 + ? msgSignature : signaturePtr.Slice(0, spaceIndex); + + signaturePtr = signaturePtr.Slice(spaceIndex + 1); + spaceIndex = signaturePtr.IndexOf(' '); + + var commaIndex = versionedSignature.IndexOf(','); + if (commaIndex < 0) { throw new WebhookVerificationException("Invalid Signature Headers"); } - var version = parts[0]; - var passedSignature = parts[1]; - - if (version != "v1") + var version = versionedSignature.Slice(0, commaIndex); + if (!version.Equals("v1", StringComparison.InvariantCulture)) { continue; } + var passedSignature = versionedSignature.Slice(commaIndex + 1); if (Utils.SecureCompare(expectedSignature, passedSignature)) { return; } - } + while(spaceIndex >= 0); + throw new WebhookVerificationException("No matching signature found"); } - private static DateTimeOffset VerifyTimestamp(string timestampHeader) + private static void VerifyTimestamp(ReadOnlySpan timestampHeader) { DateTimeOffset timestamp; var now = DateTimeOffset.UtcNow; @@ -105,26 +133,70 @@ private static DateTimeOffset VerifyTimestamp(string timestampHeader) throw new WebhookVerificationException("Invalid Signature Headers"); } - if (timestamp < (now.AddSeconds(-1 * TOLERANCE_IN_SECONDS))) + if (timestamp < now.AddSeconds(-1 * TOLERANCE_IN_SECONDS)) { throw new WebhookVerificationException("Message timestamp too old"); } - if (timestamp > (now.AddSeconds(TOLERANCE_IN_SECONDS))) + if (timestamp > now.AddSeconds(TOLERANCE_IN_SECONDS)) { throw new WebhookVerificationException("Message timestamp too new"); } - return timestamp; + } + + public string Sign(ReadOnlySpan msgId, DateTimeOffset timestamp, ReadOnlySpan payload) + { + Span signature = stackalloc char[SIGNATURE_LENGTH_STRING]; + signature[0] = 'v'; + signature[1] = '1'; + signature[2] = ','; + CalculateSignature(msgId, timestamp.ToUnixTimeSeconds().ToString(), payload, signature.Slice(3), out var charsWritten); + return signature.Slice(0, charsWritten + 3).ToString(); } - public string Sign(string msgId, DateTimeOffset timestamp, string payload) + private void CalculateSignature( + ReadOnlySpan msgId, + ReadOnlySpan timestamp, + ReadOnlySpan payload, + Span signature, + out int charsWritten) + { + // Estimate buffer size and use stackalloc for smaller allocations + int msgIdLength = SafeUTF8Encoding.GetByteCount(msgId); + int payloadLength = SafeUTF8Encoding.GetByteCount(payload); + int timestampLength = SafeUTF8Encoding.GetByteCount(timestamp); + int totalLength = msgIdLength + 1 + timestampLength + 1 + payloadLength; + + Span toSignBytes = totalLength <= MAX_STACKALLOC + ? stackalloc byte[totalLength] + : new byte[totalLength]; + + SafeUTF8Encoding.GetBytes(msgId, toSignBytes.Slice(0, msgIdLength)); + toSignBytes[msgIdLength] = (byte)'.'; + SafeUTF8Encoding.GetBytes(timestamp, toSignBytes.Slice(msgIdLength + 1, timestampLength)); + toSignBytes[msgIdLength + 1 + timestampLength] = (byte)'.'; + SafeUTF8Encoding.GetBytes(payload, toSignBytes.Slice(msgIdLength + 1 + timestampLength + 1)); + + Span signatureBin = stackalloc byte[SIGNATURE_LENGTH_BYTES]; + CalculateSignature(toSignBytes, signatureBin); + + Span signatureB64 = stackalloc byte[SIGNATURE_LENGTH_BASE64]; + var result = Base64.EncodeToUtf8(signatureBin, signatureB64, out _, out var bytesWritten); + if (result != OperationStatus.Done) + throw new WebhookVerificationException("Failed to encode signature to base64"); + + if (!SafeUTF8Encoding.TryGetChars(signatureB64.Slice(0, bytesWritten), signature, out charsWritten)) + throw new WebhookVerificationException("Failed to convert signature to utf8"); + } + + private void CalculateSignature(ReadOnlySpan input, Span output) { - var toSign = $"{msgId}.{timestamp.ToUnixTimeSeconds().ToString()}.{payload}"; - var toSignBytes = SafeUTF8Encoding.GetBytes(toSign); - using (var hmac = new HMACSHA256(this.key)) + try + { + HMACSHA256.HashData(this.key, input, output); + } + catch (Exception) { - var hash = hmac.ComputeHash(toSignBytes); - var signature = Convert.ToBase64String(hash); - return $"v1,{signature}"; + throw new WebhookVerificationException("Output buffer too small"); } } }