From 42573aef856f9a0a5fb937d3ce23ee7d859266cb Mon Sep 17 00:00:00 2001 From: Niels Swimberghe <3382717+Swimburger@users.noreply.github.com> Date: Tue, 27 Dec 2022 14:56:37 -0500 Subject: [PATCH] Make Request Validation async to load the form async --- .../RequestValidationHelperTests.cs | 64 ++++++++++-------- .../TwilioClientTests.cs | 6 +- .../ValidateRequestAttributeTests.cs | 51 ++++++++------ .../ValidateTwilioRequestFilterTests.cs | 8 ++- .../RequestValidationHelper.cs | 66 +++++++++++++++++-- .../ValidateRequestAttribute.cs | 30 +++++---- .../ValidateTwilioRequestFilter.cs | 18 ++--- .../ValidateTwilioRequestMiddleware.cs | 2 +- 8 files changed, 165 insertions(+), 80 deletions(-) diff --git a/src/Twilio.AspNet.Core.UnitTests/RequestValidationHelperTests.cs b/src/Twilio.AspNet.Core.UnitTests/RequestValidationHelperTests.cs index 3be4940..29c6385 100644 --- a/src/Twilio.AspNet.Core.UnitTests/RequestValidationHelperTests.cs +++ b/src/Twilio.AspNet.Core.UnitTests/RequestValidationHelperTests.cs @@ -4,6 +4,8 @@ using System.Net; using System.Security.Cryptography; using System.Text; +using System.Threading; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Primitives; using Xunit; @@ -15,7 +17,8 @@ public class ContextMocks public Moq.Mock HttpContext { get; set; } public Moq.Mock Request { get; set; } - public ContextMocks(bool isLocal, FormCollection form = null, bool isProxied = false) : this("", isLocal, form, isProxied) + public ContextMocks(bool isLocal, FormCollection form = null, bool isProxied = false) : this("", isLocal, form, + isProxied) { } @@ -48,6 +51,8 @@ public ContextMocks(string urlOverride, bool isLocal, FormCollection form = null { Request.Setup(x => x.Method).Returns("POST"); Request.Setup(x => x.Form).Returns(form); + Request.Setup(x => x.ReadFormAsync(new CancellationToken())) + .Returns(() => Task.FromResult((IFormCollection)form)); Request.Setup(x => x.HasFormContentType).Returns(true); } } @@ -80,81 +85,84 @@ private string CalculateSignature(string urlOverride, FormCollection form) public class RequestValidationHelperTests { [Fact] - public void TestLocal() + public async Task TestLocal() { - var fakeContext = (new ContextMocks(true)).HttpContext.Object; - var result = RequestValidationHelper.IsValidRequest(fakeContext, "bad-token", true); + var fakeContext = new ContextMocks(true).HttpContext.Object; + var result = await RequestValidationHelper.IsValidRequestAsync(fakeContext, "bad-token", true); Assert.True(result); } [Fact] - public void TestNoLocalDueToProxy() + public async Task TestNoLocalDueToProxy() { - var fakeContext = (new ContextMocks(true, isProxied: true)).HttpContext.Object; - var result = RequestValidationHelper.IsValidRequest(fakeContext, "bad-token", true); + var fakeContext = new ContextMocks(true, isProxied: true).HttpContext.Object; + var result = await RequestValidationHelper.IsValidRequestAsync(fakeContext, "bad-token", true); Assert.False(result); } [Fact] - public void TestNoLocal() + public async Task TestNoLocal() { - var fakeContext = (new ContextMocks(true)).HttpContext.Object; - var result = RequestValidationHelper.IsValidRequest(fakeContext, "bad-token", false); + var fakeContext = new ContextMocks(true).HttpContext.Object; + var result = await RequestValidationHelper.IsValidRequestAsync(fakeContext, "bad-token", false); Assert.False(result); } [Fact] - public void TestNoForm() + public async Task TestNoForm() { - var fakeContext = (new ContextMocks(true)).HttpContext.Object; - var result = RequestValidationHelper.IsValidRequest(fakeContext, ContextMocks.fakeAuthToken, false); + var fakeContext = new ContextMocks(true).HttpContext.Object; + var result = await RequestValidationHelper.IsValidRequestAsync(fakeContext, ContextMocks.fakeAuthToken, false); Assert.True(result); } [Fact] - public void TestBadForm() + public async Task TestBadForm() { var contextMocks = new ContextMocks(true); var fakeContext = contextMocks.HttpContext.Object; contextMocks.Request.Setup(x => x.Method).Returns("POST"); contextMocks.Request.Setup(x => x.Form).Throws(new Exception("poof!")); - var result = RequestValidationHelper.IsValidRequest(fakeContext, ContextMocks.fakeAuthToken, false); + var result = await RequestValidationHelper.IsValidRequestAsync(fakeContext, ContextMocks.fakeAuthToken, false); Assert.True(result); } [Fact] - public void TestUrlOverrideFail() + public async Task TestUrlOverrideFail() { - var fakeContext = (new ContextMocks(true)).HttpContext.Object; - var result = RequestValidationHelper.IsValidRequest(fakeContext, ContextMocks.fakeAuthToken, "https://example.com/", false); + var fakeContext = new ContextMocks(true).HttpContext.Object; + var result = await RequestValidationHelper.IsValidRequestAsync(fakeContext, ContextMocks.fakeAuthToken, + "https://example.com/", false); Assert.False(result); } [Fact] - public void TestUrlOverride() + public async Task TestUrlOverride() { - var fakeContext = (new ContextMocks("https://example.com/", true)).HttpContext.Object; - var result = RequestValidationHelper.IsValidRequest(fakeContext, ContextMocks.fakeAuthToken, "https://example.com/", false); + var fakeContext = new ContextMocks("https://example.com/", true).HttpContext.Object; + var result = await RequestValidationHelper.IsValidRequestAsync(fakeContext, ContextMocks.fakeAuthToken, + "https://example.com/", false); Assert.True(result); } [Fact] - public void TestForm() + public async Task TestForm() { - var form = new FormCollection(new Dictionary() { - {"key1", "value1"}, - {"key2", "value2"} - }); - var fakeContext = (new ContextMocks(true, form)).HttpContext.Object; - var result = RequestValidationHelper.IsValidRequest(fakeContext, ContextMocks.fakeAuthToken, false); + var form = new FormCollection(new Dictionary + { + {"key1", "value1"}, + {"key2", "value2"} + }); + var fakeContext = new ContextMocks(true, form).HttpContext.Object; + var result = await RequestValidationHelper.IsValidRequestAsync(fakeContext, ContextMocks.fakeAuthToken, false); Assert.True(result); } diff --git a/src/Twilio.AspNet.Core.UnitTests/TwilioClientTests.cs b/src/Twilio.AspNet.Core.UnitTests/TwilioClientTests.cs index a2c3571..6fb17b4 100644 --- a/src/Twilio.AspNet.Core.UnitTests/TwilioClientTests.cs +++ b/src/Twilio.AspNet.Core.UnitTests/TwilioClientTests.cs @@ -18,7 +18,7 @@ public class TwilioClientTests private static readonly TwilioOptions ValidTwilioOptions = new() { AuthToken = "My Twilio:AuthToken", - Client = new TwilioClientOptions() + Client = new TwilioClientOptions { AccountSid = "MyAccountSid!", AuthToken = "My Twilio:Client:AuthToken", @@ -32,7 +32,7 @@ public class TwilioClientTests private static readonly TwilioOptions AuthTokenTwilioOptions = new() { - Client = new TwilioClientOptions() + Client = new TwilioClientOptions { AccountSid = "MyAccountSid!", AuthToken = "My Twilio:Client:AuthToken", @@ -44,7 +44,7 @@ public class TwilioClientTests private static readonly TwilioOptions ApiKeyTwilioOptions = new() { - Client = new TwilioClientOptions() + Client = new TwilioClientOptions { AccountSid = "MyAccountSid!", ApiKeySid = "My API Key SID", diff --git a/src/Twilio.AspNet.Core.UnitTests/ValidateRequestAttributeTests.cs b/src/Twilio.AspNet.Core.UnitTests/ValidateRequestAttributeTests.cs index 774cbad..8b52b19 100644 --- a/src/Twilio.AspNet.Core.UnitTests/ValidateRequestAttributeTests.cs +++ b/src/Twilio.AspNet.Core.UnitTests/ValidateRequestAttributeTests.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Net; using System.Text.Json; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.Abstractions; @@ -20,7 +21,7 @@ public class ValidateRequestAttributeTests private static readonly TwilioOptions ValidTwilioOptions = new() { AuthToken = "My Twilio:AuthToken", - RequestValidation = new TwilioRequestValidationOptions() + RequestValidation = new TwilioRequestValidationOptions { AuthToken = "My Twilio:RequestValidation:AuthToken", AllowLocal = false, @@ -85,7 +86,7 @@ public void AddTwilio_Should_Configure_ValidateRequestAttribute() var attributeFactory = new ValidateRequestAttribute(); var attribute = - (ValidateRequestAttribute.InternalValidateRequestAttribute) attributeFactory + (ValidateRequestAttribute.InternalValidateRequestFilter) attributeFactory .CreateInstance(serviceProvider); Assert.Equal(ValidTwilioOptions.RequestValidation.AllowLocal, attribute.AllowLocal); @@ -101,22 +102,22 @@ public void Creating_ValidateRequestAttribute_Without_AddTwilioClient_Should_Thr var attributeFactory = new ValidateRequestAttribute(); var exception = Assert.Throws(() => - (ValidateRequestAttribute.InternalValidateRequestAttribute) attributeFactory + (ValidateRequestAttribute.InternalValidateRequestFilter) attributeFactory .CreateInstance(serviceProvider)); Assert.Equal("RequestValidationOptions is not configured.", exception.Message); } [Fact] - public void ValidateRequestAttribute_Validates_Request_Successfully() + public async Task ValidateRequestAttribute_Validates_Request_Successfully() { - var attribute = new ValidateRequestAttribute.InternalValidateRequestAttribute( - authToken: ContextMocks.fakeAuthToken, - null, + var attribute = new ValidateRequestAttribute.InternalValidateRequestFilter( + authToken: ContextMocks.fakeAuthToken, + null, false ); - var form = new FormCollection(new Dictionary() + var form = new FormCollection(new Dictionary { {"key1", "value1"}, {"key2", "value2"} @@ -129,22 +130,27 @@ public void ValidateRequestAttribute_Validates_Request_Successfully() new Dictionary(), new object() ); - - attribute.OnActionExecuting(actionExecutingContext); - + var actionExecutedContext = new ActionExecutedContext( + new ActionContext(fakeContext, new RouteData(), new ActionDescriptor()), + new List(), + new object() + ); + + await attribute.OnActionExecutionAsync(actionExecutingContext, () => Task.FromResult(actionExecutedContext)); + Assert.Null(actionExecutingContext.Result); } [Fact] - public void ValidateRequestFilter_Validates_Request_Forbid() + public async Task ValidateRequestFilter_Validates_Request_Forbid() { - var attribute = new ValidateRequestAttribute.InternalValidateRequestAttribute( - authToken: "bad", - null, + var attribute = new ValidateRequestAttribute.InternalValidateRequestFilter( + authToken: "bad", + null, false ); - var form = new FormCollection(new Dictionary() + var form = new FormCollection(new Dictionary { {"key1", "value1"}, {"key2", "value2"} @@ -157,11 +163,16 @@ public void ValidateRequestFilter_Validates_Request_Forbid() new Dictionary(), new object() ); - - attribute.OnActionExecuting(actionExecutingContext); + var actionExecutedContext = new ActionExecutedContext( + new ActionContext(fakeContext, new RouteData(), new ActionDescriptor()), + new List(), + new object() + ); + + await attribute.OnActionExecutionAsync(actionExecutingContext, () => Task.FromResult(actionExecutedContext)); - var statusCodeResult = (StatusCodeResult)actionExecutingContext.Result!; + var statusCodeResult = (StatusCodeResult) actionExecutingContext.Result!; Assert.NotNull(statusCodeResult); - Assert.Equal((int)HttpStatusCode.Forbidden, statusCodeResult.StatusCode); + Assert.Equal((int) HttpStatusCode.Forbidden, statusCodeResult.StatusCode); } } \ No newline at end of file diff --git a/src/Twilio.AspNet.Core.UnitTests/ValidateTwilioRequestFilterTests.cs b/src/Twilio.AspNet.Core.UnitTests/ValidateTwilioRequestFilterTests.cs index bfd2279..7098809 100644 --- a/src/Twilio.AspNet.Core.UnitTests/ValidateTwilioRequestFilterTests.cs +++ b/src/Twilio.AspNet.Core.UnitTests/ValidateTwilioRequestFilterTests.cs @@ -15,7 +15,7 @@ public class ValidateTwilioRequestFilterTests private static readonly TwilioOptions ValidTwilioOptions = new() { AuthToken = "My Twilio:AuthToken", - RequestValidation = new TwilioRequestValidationOptions() + RequestValidation = new TwilioRequestValidationOptions { AuthToken = "My Twilio:RequestValidation:AuthToken", AllowLocal = false, @@ -67,7 +67,8 @@ public async Task ValidateRequestFilter_Validates_Request_Successfully() var serviceProvider = serviceCollection.BuildServiceProvider(); var filter = serviceProvider.GetRequiredService(); - var form = new FormCollection(new Dictionary() { + var form = new FormCollection(new Dictionary + { {"key1", "value1"}, {"key2", "value2"} }); @@ -95,7 +96,8 @@ public async Task ValidateRequestFilter_Validates_Request_Forbid() var serviceProvider = serviceCollection.BuildServiceProvider(); var filter = serviceProvider.GetRequiredService(); - var form = new FormCollection(new Dictionary() { + var form = new FormCollection(new Dictionary + { {"key1", "value1"}, {"key2", "value2"} }); diff --git a/src/Twilio.AspNet.Core/RequestValidationHelper.cs b/src/Twilio.AspNet.Core/RequestValidationHelper.cs index 192fd92..91bd05a 100644 --- a/src/Twilio.AspNet.Core/RequestValidationHelper.cs +++ b/src/Twilio.AspNet.Core/RequestValidationHelper.cs @@ -1,5 +1,7 @@ -using System.Linq; +using System.Collections.Generic; +using System.Linq; using System.Net; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Twilio.Security; @@ -11,6 +13,60 @@ namespace Twilio.AspNet.Core /// public static class RequestValidationHelper { + /// + /// Performs request validation using the current HTTP context passed in manually or from + /// the ASP.NET MVC ValidateRequestAttribute + /// + /// HttpContext to use for validation + /// AuthToken for the account used to sign the request + /// Skip validation for local requests + public static Task IsValidRequestAsync(HttpContext context, string authToken, bool allowLocal = true) + => IsValidRequestAsync(context, authToken, null, allowLocal); + + /// + /// Performs request validation using the current HTTP context passed in manually or from + /// the ASP.NET MVC ValidateRequestAttribute + /// + /// HttpContext to use for validation + /// AuthToken for the account used to sign the request + /// The URL to use for validation, if different from Request.Url (sometimes needed if web site is behind a proxy or load-balancer) + /// Skip validation for local requests + public static async Task IsValidRequestAsync( + HttpContext context, + string authToken, + string urlOverride, + bool allowLocal = true + ) + { + var request = context.Request; + + if (allowLocal && IsLocal(request)) + { + return true; + } + + // validate request + // http://www.twilio.com/docs/security-reliability/security + // Take the full URL of the request, from the protocol (http...) through the end of the query string (everything after the ?) + string fullUrl = string.IsNullOrEmpty(urlOverride) + ? $"{request.Scheme}://{(request.IsHttps ? request.Host.Host : request.Host.ToUriComponent())}{request.Path}{request.QueryString}" + : urlOverride; + + Dictionary parameters = null; + if (request.HasFormContentType) + { + var form = await request.ReadFormAsync(context.RequestAborted).ConfigureAwait(false); + parameters = form.ToDictionary(kv => kv.Key, kv => kv.Value.ToString()); + } + + var validator = new RequestValidator(authToken); + return validator.Validate( + url: fullUrl, + parameters: parameters, + expected: request.Headers["X-Twilio-Signature"] + ); + } + /// /// Performs request validation using the current HTTP context passed in manually or from /// the ASP.NET MVC ValidateRequestAttribute @@ -50,9 +106,11 @@ public static bool IsValidRequest( ? $"{request.Scheme}://{(request.IsHttps ? request.Host.Host : request.Host.ToUriComponent())}{request.Path}{request.QueryString}" : urlOverride; - var parameters = request.HasFormContentType - ? request.Form.Keys.ToDictionary(k => k, k => request.Form[k].ToString()) - : null; + Dictionary parameters = null; + if (request.HasFormContentType) + { + parameters = request.Form.ToDictionary(kv => kv.Key, kv => kv.Value.ToString()); + } var validator = new RequestValidator(authToken); return validator.Validate( diff --git a/src/Twilio.AspNet.Core/ValidateRequestAttribute.cs b/src/Twilio.AspNet.Core/ValidateRequestAttribute.cs index a4189ce..881897e 100644 --- a/src/Twilio.AspNet.Core/ValidateRequestAttribute.cs +++ b/src/Twilio.AspNet.Core/ValidateRequestAttribute.cs @@ -1,5 +1,6 @@ using System; using System.Net; +using System.Threading.Tasks; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.Filters; using Microsoft.Extensions.DependencyInjection; @@ -26,18 +27,18 @@ public IFilterMetadata CreateInstance(IServiceProvider serviceProvider) var options = serviceProvider.GetService>()?.Value; if (options == null) throw new Exception("RequestValidationOptions is not configured."); - return new InternalValidateRequestAttribute( + return new InternalValidateRequestFilter( authToken: options.AuthToken, baseUrlOverride: options.BaseUrlOverride?.TrimEnd('/'), allowLocal: options.AllowLocal ?? true ); } - internal class InternalValidateRequestAttribute : ActionFilterAttribute + internal class InternalValidateRequestFilter : IAsyncActionFilter { - internal string AuthToken { get; set; } - internal string BaseUrlOverride { get; set; } - internal bool AllowLocal { get; set; } + internal string AuthToken { get; } + internal string BaseUrlOverride { get; } + internal bool AllowLocal { get; } /// /// Initializes a new instance of the ValidateRequestAttribute class. @@ -48,30 +49,33 @@ internal class InternalValidateRequestAttribute : ActionFilterAttribute /// if different from Request.Url (sometimes needed if web site is behind a proxy or load-balancer) /// /// Skip validation for local requests - public InternalValidateRequestAttribute(string authToken, string baseUrlOverride, bool allowLocal) + public InternalValidateRequestFilter(string authToken, string baseUrlOverride, bool allowLocal) { AuthToken = authToken; BaseUrlOverride = baseUrlOverride; AllowLocal = allowLocal; } - public override void OnActionExecuting(ActionExecutingContext filterContext) + public async Task OnActionExecutionAsync(ActionExecutingContext context, ActionExecutionDelegate next) { - var httpContext = filterContext.HttpContext; + var httpContext = context.HttpContext; var request = httpContext.Request; string urlOverride = null; if (BaseUrlOverride != null) { urlOverride = $"{BaseUrlOverride}{request.Path}{request.QueryString}"; } - - if (!RequestValidationHelper.IsValidRequest(httpContext, AuthToken, urlOverride, AllowLocal)) + + var isValid = await RequestValidationHelper + .IsValidRequestAsync(httpContext, AuthToken, urlOverride, AllowLocal).ConfigureAwait(false); + if (!isValid) { - filterContext.Result = new StatusCodeResult((int)HttpStatusCode.Forbidden); + context.Result = new StatusCodeResult((int) HttpStatusCode.Forbidden); + return; } - base.OnActionExecuting(filterContext); + await next(); } } } -} +} \ No newline at end of file diff --git a/src/Twilio.AspNet.Core/ValidateTwilioRequestFilter.cs b/src/Twilio.AspNet.Core/ValidateTwilioRequestFilter.cs index fb01f8b..177c151 100644 --- a/src/Twilio.AspNet.Core/ValidateTwilioRequestFilter.cs +++ b/src/Twilio.AspNet.Core/ValidateTwilioRequestFilter.cs @@ -14,9 +14,9 @@ namespace Twilio.AspNet.Core; /// public class ValidateTwilioRequestFilter : IEndpointFilter { - internal string AuthToken { get; set; } - internal string BaseUrlOverride { get; set; } - internal bool AllowLocal { get; set; } + internal string AuthToken { get; } + internal string BaseUrlOverride { get; } + internal bool AllowLocal { get; } public ValidateTwilioRequestFilter(IServiceProvider serviceProvider) { @@ -36,17 +36,19 @@ EndpointFilterDelegate next var httpContext = efiContext.HttpContext; var request = httpContext.Request; string urlOverride = null; - if (BaseUrlOverride != null) + if (!string.IsNullOrEmpty(BaseUrlOverride)) { urlOverride = $"{BaseUrlOverride}{request.Path}{request.QueryString}"; } - if (RequestValidationHelper.IsValidRequest(httpContext, AuthToken, urlOverride, AllowLocal)) + var isValid = await RequestValidationHelper.IsValidRequestAsync(httpContext, AuthToken, urlOverride, AllowLocal) + .ConfigureAwait(false); + if (!isValid) { - return await next(efiContext); + return Results.StatusCode((int) HttpStatusCode.Forbidden); } - - return Results.StatusCode((int) HttpStatusCode.Forbidden); + + return await next(efiContext); } } diff --git a/src/Twilio.AspNet.Core/ValidateTwilioRequestMiddleware.cs b/src/Twilio.AspNet.Core/ValidateTwilioRequestMiddleware.cs index fce0aee..398cbeb 100644 --- a/src/Twilio.AspNet.Core/ValidateTwilioRequestMiddleware.cs +++ b/src/Twilio.AspNet.Core/ValidateTwilioRequestMiddleware.cs @@ -34,7 +34,7 @@ public async Task InvokeAsync(HttpContext context) urlOverride = $"{options.BaseUrlOverride.TrimEnd('/')}{request.Path}{request.QueryString}"; } - var isValid = RequestValidationHelper.IsValidRequest(context, options.AuthToken, urlOverride, options.AllowLocal ?? true); + var isValid = await RequestValidationHelper.IsValidRequestAsync(context, options.AuthToken, urlOverride, options.AllowLocal ?? true); if (!isValid) { context.Response.StatusCode = (int) HttpStatusCode.Forbidden;