diff --git a/README.md b/README.md index 0ad42a4..b0bb5e7 100644 --- a/README.md +++ b/README.md @@ -160,7 +160,7 @@ In some situations, you may need to apply different security headers to differen ### 1. Configure your policies using `AddSecurityHeaderPolicies()` -You can configure named and default policies by calling `AddSecurityHeaderPolicies()` on `IServiceCollection`. You can configure the default policy to use, as well as any named policies. For example, the following configures the default policy (used when `UseSecurityHeaders()` is called without any arguments), and a named policy: +You can configure named and default policies by calling `AddSecurityHeaderPolicies()` on `IServiceCollection`. You can configure the default policy to use, as well as any named policies. For example, the following configures the default policy (used for all requests that are not customised for an endpoint), and a named policy: ```csharp var builder = WebApplication.CreateBuilder(); @@ -172,9 +172,9 @@ builder.Services.AddSecurityHeaderPolicies() ``` -### 2. Add the default middleware early to the pipeline +### 2. Call `UseSecurityHeaders()` early in the middleware pipeline -The security headers middleware can only add headers to _all_ requests if it is early in the middleware pipeline, so it's important to add the headders middleware at the start of your middleware pipeline. However, if you want to have endpoint-specific policies, then you _also_ need to place the middleware after the call to `UseRouting()`, as that is the point at which the endpoint that will be executed is selected. +The security headers middleware can only add headers to _all_ requests if it is early in the middleware pipeline, so it's important to add the headers middleware at the start of your middleware pipeline by calling `UseSecurityHeaders()`. However, if you want to have endpoint-specific policies, then you also need to call `UseEndpointSecurityHeaders()` _after_ the call to `UseRouting()`. For example: ```csharp var builder = WebApplication.CreateBuilder(); @@ -194,15 +194,13 @@ app.UseStaticFiles(); // other middleware app.UseAuthentication(); app.UseRouting(); -app.UseSecurityHeaders(); // 👈 Add after the routing middleware +app.UseEndpointSecurityHeaders(); // 👈 Add after the routing middleware app.UseAuthorization(); app.MapGet("/", () => "Hello world"); app.Run(); ``` -Note that if you pass a policy to any call to `UseSecurityHeaders()` it will override the "default" policy used at that point. - ### 3. Apply custom policies to endpoints To apply a non-default policy to an endpoint, use the `WithSecurityHeadersPolicy(policy)` endpoint extension method, and pass in the name of the policy to apply: @@ -243,14 +241,13 @@ public class HomeController : ControllerBase } ``` -Each call to `UseSecurityHeaders()` will re-evaluate the applicable policies; the headers are applied just before the response is sent. The policy to apply is determined as follows, with the first applicable policy selected. +Security headers are applied just before the response is sent. If you use the configuration described above, then the policy to apply is determined as follows, with the first applicable policy selected: 1. If an endpoint has been selected, and a named policy is applied, use that. -2. If a named or policy instance is passed to the `SecurityHeadersMiddleware`, use that. +2. If a named or policy instance is passed to the `SecurityHeadersMiddleware()`, use that. 3. If the default policy has been set using `SetDefaultPolicy()`, use that. 4. Otherwise, apply the default headers (those added by `AddDefaultSecurityHeaders()`) - ## RemoveServerHeader One point to be aware of is that the `RemoveServerHeader` method will rarely (ever?) be sufficient to remove the `Server` header from your output. If any subsequent middleware in your application pipeline add the header, then this will be able to remove it. However Kestrel will generally add the `Server` header too late in the pipeline to be able to modify it. diff --git a/src/NetEscapades.AspNetCore.SecurityHeaders/EndpointSecurityHeadersMiddleware.cs b/src/NetEscapades.AspNetCore.SecurityHeaders/EndpointSecurityHeadersMiddleware.cs new file mode 100644 index 0000000..d290e26 --- /dev/null +++ b/src/NetEscapades.AspNetCore.SecurityHeaders/EndpointSecurityHeadersMiddleware.cs @@ -0,0 +1,67 @@ +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using NetEscapades.AspNetCore.SecurityHeaders.Headers; +using NetEscapades.AspNetCore.SecurityHeaders.Infrastructure; + +namespace NetEscapades.AspNetCore.SecurityHeaders; + +/// +/// An ASP.NET Core middleware for adding security headers. +/// +internal class EndpointSecurityHeadersMiddleware +{ + private readonly RequestDelegate _next; + private readonly ILogger _logger; + private readonly CustomHeaderOptions _options; + + /// + /// Initializes a new instance of the class. + /// + /// The next middleware in the pipeline. + /// A logger for recording errors. + /// Options on how to control the settings that are applied + public EndpointSecurityHeadersMiddleware(RequestDelegate next, ILogger logger, CustomHeaderOptions options) + { + _next = next ?? throw new ArgumentNullException(nameof(next)); + _logger = logger; + _options = options; + } + + /// + /// Invoke the middleware + /// + /// The current context + /// A representing the asynchronous operation. + public Task Invoke(HttpContext context) + { + // Policy resolution rules: + // + // 1. If there is an endpoint with a named policy, then fetch that policy + // 2. Use the provided default policy + var endpoint = context.GetEndpoint(); + var metadata = endpoint?.Metadata.GetMetadata(); + + if (!string.IsNullOrEmpty(metadata?.PolicyName)) + { + if (_options.GetPolicy(metadata.PolicyName) is { } namedPolicy) + { + context.Items[SecurityHeadersMiddleware.HttpContextKey] = namedPolicy; + } + else + { + // log that we couldn't find the policy + _logger.LogWarning( + "Error configuring security headers middleware: policy '{PolicyName}' could not be found. " + + "Configure the policies for your application by calling AddSecurityHeaderPolicies() on IServiceCollection " + + "and adding a policy with the required name.", + metadata.PolicyName); + } + } + + return _next(context); + } +} \ No newline at end of file diff --git a/src/NetEscapades.AspNetCore.SecurityHeaders/SecurityHeadersMiddleware.cs b/src/NetEscapades.AspNetCore.SecurityHeaders/SecurityHeadersMiddleware.cs index 331f12a..752952b 100644 --- a/src/NetEscapades.AspNetCore.SecurityHeaders/SecurityHeadersMiddleware.cs +++ b/src/NetEscapades.AspNetCore.SecurityHeaders/SecurityHeadersMiddleware.cs @@ -14,10 +14,12 @@ namespace NetEscapades.AspNetCore.SecurityHeaders; /// internal class SecurityHeadersMiddleware { - private const string HttpContextKey = "__NetEscapades.AspNetCore.SecurityHeaders"; + /// + /// The HttpContext key that tracks the policy to apply + /// + internal const string HttpContextKey = "__NetEscapades.AspNetCore.SecurityHeaders"; + private readonly RequestDelegate _next; - private readonly ILogger _logger; - private readonly CustomHeaderOptions? _options; private readonly HeaderPolicyCollection _defaultPolicy; private readonly NonceGenerator? _nonceGenerator; @@ -25,14 +27,10 @@ internal class SecurityHeadersMiddleware /// Initializes a new instance of the class. /// /// The next middleware in the pipeline. - /// A logger for recording errors. - /// Options on how to control the settings that are applied - /// A containing the policies to be applied. - public SecurityHeadersMiddleware(RequestDelegate next, ILogger logger, CustomHeaderOptions? options, HeaderPolicyCollection defaultPolicy) + /// A containing the policy to apply by default. + public SecurityHeadersMiddleware(RequestDelegate next, HeaderPolicyCollection defaultPolicy) { _next = next ?? throw new ArgumentNullException(nameof(next)); - _logger = logger; - _options = options; _defaultPolicy = defaultPolicy ?? throw new ArgumentNullException(nameof(defaultPolicy)); _nonceGenerator = MustGenerateNonce(_defaultPolicy) ? new() : null; } @@ -42,47 +40,17 @@ public SecurityHeadersMiddleware(RequestDelegate next, ILogger /// The current context /// A representing the asynchronous operation. - public async Task Invoke(HttpContext context) + public Task Invoke(HttpContext context) { - // Policy resolution rules: - // - // 1. If there is an endpoint with a named policy, then fetch that policy - // 2. Use the provided default policy - var endpoint = context.GetEndpoint(); - var metadata = endpoint?.Metadata.GetMetadata(); - - HeaderPolicyCollection policy = _defaultPolicy; - - if (!string.IsNullOrEmpty(metadata?.PolicyName)) - { - if (_options?.GetPolicy(metadata.PolicyName) is { } namedPolicy) - { - policy = namedPolicy; - } - else - { - // log that we couldn't find the policy - _logger.LogWarning( - "Error configuring security headers middleware: policy '{PolicyName}' could not be found. " - + "Configure the policies for your application by calling AddSecurityHeaderPolicies() on IServiceCollection " - + "and adding a policy with the required name. Using default policy for request", - metadata.PolicyName); - } - } - - if (context.Items[HttpContextKey] is null) + // Write into the context, so that subsequent requests can "overwrite" it + context.Items[HttpContextKey] = _defaultPolicy; + context.Response.OnStarting(OnResponseStarting, context); + if (_nonceGenerator is not null) { - context.Response.OnStarting(OnResponseStarting, context); - if (_nonceGenerator is not null) - { - context.SetNonce(_nonceGenerator.GetNonce(Constants.DefaultBytesInNonce)); - } + context.SetNonce(_nonceGenerator.GetNonce(Constants.DefaultBytesInNonce)); } - // Write into the context, so that subsequent requests can "overwrite" it - context.Items[HttpContextKey] = policy; - - await _next(context); + return _next(context); } private static Task OnResponseStarting(object state) diff --git a/src/NetEscapades.AspNetCore.SecurityHeaders/SecurityHeadersMiddlewareExtensions.cs b/src/NetEscapades.AspNetCore.SecurityHeaders/SecurityHeadersMiddlewareExtensions.cs index 527c949..c5853ae 100644 --- a/src/NetEscapades.AspNetCore.SecurityHeaders/SecurityHeadersMiddlewareExtensions.cs +++ b/src/NetEscapades.AspNetCore.SecurityHeaders/SecurityHeadersMiddlewareExtensions.cs @@ -30,8 +30,7 @@ public static IApplicationBuilder UseSecurityHeaders(this IApplicationBuilder ap throw new ArgumentNullException(nameof(policies)); } - var options = (CustomHeaderOptions)app.ApplicationServices.GetService(typeof(CustomHeaderOptions)); - return app.UseSecurityHeaders(options, policies); + return app.UseMiddleware(policies); } /// @@ -70,7 +69,7 @@ public static IApplicationBuilder UseSecurityHeaders(this IApplicationBuilder ap var options = app.ApplicationServices.GetService(typeof(CustomHeaderOptions)) as CustomHeaderOptions; var policy = options?.DefaultPolicy ?? new HeaderPolicyCollection().AddDefaultSecurityHeaders(); - return app.UseSecurityHeaders(options, policy); + return app.UseSecurityHeaders(policy); } /// @@ -95,20 +94,39 @@ public static IApplicationBuilder UseSecurityHeaders(this IApplicationBuilder ap var policy = options?.GetPolicy(policyName); if (policy is null) { - var log = ((ILoggerFactory)app.ApplicationServices.GetRequiredService(typeof(ILoggerFactory))).CreateLogger(typeof(SecurityHeadersMiddlewareExtensions)); - log.LogWarning( - "Error configuring security headers middleware: policy '{PolicyName}' could not be found. " - + "Configure the policies for your application by calling AddSecurityHeaderPolicies() on IServiceCollection " - + "and adding a policy with the required name.", - policyName); - return app; + throw new InvalidOperationException( + $"Error configuring security headers middleware: policy '{policyName}' could not be found. " + + "Configure the policies for your application by calling IServiceCollection.AddSecurityHeaderPolicies() " + + $"in your application startup code and add a namedpolicy called '{policyName}'"); } - return app.UseSecurityHeaders(options, policy); + return app.UseSecurityHeaders(policy); } - private static IApplicationBuilder UseSecurityHeaders(this IApplicationBuilder app, CustomHeaderOptions? options, HeaderPolicyCollection policies) + /// + /// Adds middleware to your web application pipeline which sets security headers on responses + /// based on the specific endpoint invoked. + /// + /// To apply policies to a specific endpoint, use + /// or apply to your MVC or Razor Page endpoints. + /// + /// The IApplicationBuilder passed to your Configure method. + /// The original app parameter + public static IApplicationBuilder UseEndpointSecurityHeaders(this IApplicationBuilder app) { - return app.UseMiddleware(options ?? new CustomHeaderOptions(), policies); + if (app == null) + { + throw new ArgumentNullException(nameof(app)); + } + + if (app.ApplicationServices.GetService(typeof(CustomHeaderOptions)) is not CustomHeaderOptions options) + { + throw new InvalidOperationException( + "Error configuring security headers middleware: Unable to find required services. " + + "Configure the policies for your application by calling IServiceCollection.AddSecurityHeaderPolicies() " + + "in your application startup code"); + } + + return app.UseMiddleware(options); } } \ No newline at end of file diff --git a/test/NetEscapades.AspNetCore.SecurityHeaders.Test/CspBuilderTests.cs b/test/NetEscapades.AspNetCore.SecurityHeaders.Test/CspBuilderTests.cs index 16b32ee..eca1f44 100644 --- a/test/NetEscapades.AspNetCore.SecurityHeaders.Test/CspBuilderTests.cs +++ b/test/NetEscapades.AspNetCore.SecurityHeaders.Test/CspBuilderTests.cs @@ -221,11 +221,9 @@ public void Build_AddSrciptSrc_WhenAddsNonce_ConstantValueThrowsInvalidOperation var result = builder.Build(); - result.Invoking(x => - { - var val = x.ConstantValue; - }) - .ShouldThrow(); + result.Invoking(x=>x.ConstantValue) + .Should() + .Throw(); } [Theory] [InlineData(true, false)] @@ -272,11 +270,9 @@ public void Build_AddSrciptSrc_WhenDoesntAddNonce_BuilderThrowsInvalidOperation( var result = builder.Build(); - result.Invoking(x => - { - var val = x.Builder; - }) - .ShouldThrow(); + result.Invoking(x=>x.Builder) + .Should() + .Throw(); } [Fact] diff --git a/test/NetEscapades.AspNetCore.SecurityHeaders.Test/HttpSecurityHeadersMiddlewareFunctionalTests.cs b/test/NetEscapades.AspNetCore.SecurityHeaders.Test/HttpSecurityHeadersMiddlewareFunctionalTests.cs index 6b7f6b9..523c574 100644 --- a/test/NetEscapades.AspNetCore.SecurityHeaders.Test/HttpSecurityHeadersMiddlewareFunctionalTests.cs +++ b/test/NetEscapades.AspNetCore.SecurityHeaders.Test/HttpSecurityHeadersMiddlewareFunctionalTests.cs @@ -58,7 +58,7 @@ public async Task WhenUsingEndpoint_Overrides_Default(string path, string expect content.Should().Be(expected); // no security headers - response.Headers.Should().NotContain("X-Frame-Options"); + response.Headers.Should().NotContainKey("X-Frame-Options"); response.Headers.TryGetValues("Custom-Header", out var customHeader).Should().BeTrue(); customHeader.Should().ContainSingle("MyValue"); } diff --git a/test/NetEscapades.AspNetCore.SecurityHeaders.Test/HttpsSecurityHeadersMiddlewareFunctionalTests.cs b/test/NetEscapades.AspNetCore.SecurityHeaders.Test/HttpsSecurityHeadersMiddlewareFunctionalTests.cs index fb89769..a17f47e 100644 --- a/test/NetEscapades.AspNetCore.SecurityHeaders.Test/HttpsSecurityHeadersMiddlewareFunctionalTests.cs +++ b/test/NetEscapades.AspNetCore.SecurityHeaders.Test/HttpsSecurityHeadersMiddlewareFunctionalTests.cs @@ -59,7 +59,7 @@ public async Task WhenUsingEndpoint_Overrides_Default(string path, string expect content.Should().Be(expected); // no security headers - response.Headers.Should().NotContain("X-Frame-Options"); + response.Headers.Should().NotContainKey("X-Frame-Options"); response.Headers.TryGetValues("Custom-Header", out var customHeader).Should().BeTrue(); customHeader.Should().ContainSingle("MyValue"); } diff --git a/test/NetEscapades.AspNetCore.SecurityHeaders.Test/NetEscapades.AspNetCore.SecurityHeaders.Test.csproj b/test/NetEscapades.AspNetCore.SecurityHeaders.Test/NetEscapades.AspNetCore.SecurityHeaders.Test.csproj index f66da14..9360009 100644 --- a/test/NetEscapades.AspNetCore.SecurityHeaders.Test/NetEscapades.AspNetCore.SecurityHeaders.Test.csproj +++ b/test/NetEscapades.AspNetCore.SecurityHeaders.Test/NetEscapades.AspNetCore.SecurityHeaders.Test.csproj @@ -14,7 +14,7 @@ - + diff --git a/test/NetEscapades.AspNetCore.SecurityHeaders.Test/SecurityHeadersMiddlewareTests.cs b/test/NetEscapades.AspNetCore.SecurityHeaders.Test/SecurityHeadersMiddlewareTests.cs index 06d65b0..4db07fb 100644 --- a/test/NetEscapades.AspNetCore.SecurityHeaders.Test/SecurityHeadersMiddlewareTests.cs +++ b/test/NetEscapades.AspNetCore.SecurityHeaders.Test/SecurityHeadersMiddlewareTests.cs @@ -113,7 +113,7 @@ public async Task HttpRequest_WithDefaultSecurityHeaders_WithNamedPolicy_SetsSec // Arrange var hostBuilder = new WebHostBuilder() .ConfigureServices(s => s.AddSecurityHeaderPolicies() - .AddPolicy(policyName, p => p.AddDefaultSecurityHeaders())) + .AddPolicy(policyName, p => p.AddCustomHeader("Custom-Header", "MyValue"))) .Configure(app => { app.UseSecurityHeaders(policyName); @@ -135,12 +135,13 @@ public async Task HttpRequest_WithDefaultSecurityHeaders_WithNamedPolicy_SetsSec response.EnsureSuccessStatusCode(); (await response.Content.ReadAsStringAsync()).Should().Be("Test response"); - response.Headers.AssertHttpRequestDefaultSecurityHeaders(); + response.Headers.Should().NotContainKey("X-Frame-Options"); + response.Headers.Should().ContainKey("Custom-Header").WhoseValue.Should().ContainSingle("MyValue"); } } [Fact] - public async Task HttpRequest_WithDefaultSecurityHeaders_WithUnknownNamedPolicy_DoesNotSetHeaders() + public void HttpRequest_WithDefaultSecurityHeaders_WithUnknownNamedPolicy_ThrowsException() { // Arrange var hostBuilder = new WebHostBuilder() @@ -154,6 +155,35 @@ public async Task HttpRequest_WithDefaultSecurityHeaders_WithUnknownNamedPolicy_ }); }); + Func act = () => new TestServer(hostBuilder); + act.Should().Throw(); + } + + [Fact] + public async Task HttpRequest_WithEndpointSecurityHeaders_WhenNoEndpoints_SetsDefaultHeaders() + { + var policyName = "custom"; + + // Arrange + var hostBuilder = new WebHostBuilder() + .ConfigureServices(s => + { + s.AddRouting(); + s.AddSecurityHeaderPolicies() + .AddPolicy(policyName, p => p.AddCustomHeader("Custom-Header", "MyValue")); + }) + .Configure(app => + { + app.UseSecurityHeaders(); + app.UseRouting(); + app.UseEndpointSecurityHeaders(); + app.Run(async context => + { + context.Response.ContentType = "text/html"; + await context.Response.WriteAsync("Test response"); + }); + }); + using (var server = new TestServer(hostBuilder)) { // Act @@ -165,10 +195,97 @@ public async Task HttpRequest_WithDefaultSecurityHeaders_WithUnknownNamedPolicy_ response.EnsureSuccessStatusCode(); (await response.Content.ReadAsStringAsync()).Should().Be("Test response"); - response.Headers.TryGetValues("X-Frame-Options", out _).Should().BeFalse(); + response.Headers.AssertHttpRequestDefaultSecurityHeaders(); } } - + + [Fact] + public async Task HttpRequest_WithEndpointSecurityHeaders_WhenEndpointsHasNoMetadata_SetsDefaultHeaders() + { + var policyName = "custom"; + + // Arrange + var hostBuilder = new WebHostBuilder() + .ConfigureServices(s => + { + s.AddRouting(); + s.AddSecurityHeaderPolicies() + .AddPolicy(policyName, p => p.AddCustomHeader("Custom-Header", "MyValue")); + }) + .Configure(app => + { + app.UseSecurityHeaders(); + app.UseRouting(); + app.UseEndpointSecurityHeaders(); + app.UseEndpoints(endpoints => + { + endpoints.MapPut("/", async context => + { + context.Response.ContentType = "text/html"; + await context.Response.WriteAsync("Test response"); + }); + }); + }); + + using (var server = new TestServer(hostBuilder)) + { + // Act + // Actual request. + var response = await server.CreateRequest("/") + .SendAsync("PUT"); + + // Assert + response.EnsureSuccessStatusCode(); + + (await response.Content.ReadAsStringAsync()).Should().Be("Test response"); + response.Headers.AssertHttpRequestDefaultSecurityHeaders(); + } + } + + [Fact] + public async Task HttpRequest_WithEndpointSecurityHeaders_WhenEndpointsHasMetadata_SetsCustomHeaders() + { + var policyName = "custom"; + + // Arrange + var hostBuilder = new WebHostBuilder() + .ConfigureServices(s => + { + s.AddRouting(); + s.AddSecurityHeaderPolicies() + .AddPolicy(policyName, p => p.AddCustomHeader("Custom-Header", "MyValue")); + }) + .Configure(app => + { + app.UseSecurityHeaders(); + app.UseRouting(); + app.UseEndpointSecurityHeaders(); + app.UseEndpoints(endpoints => + { + endpoints.MapPut("/", async context => + { + context.Response.ContentType = "text/html"; + await context.Response.WriteAsync("Test response"); + }).WithSecurityHeadersPolicy(policyName); + }); + }); + + using (var server = new TestServer(hostBuilder)) + { + // Act + // Actual request. + var response = await server.CreateRequest("/") + .SendAsync("PUT"); + + // Assert + response.EnsureSuccessStatusCode(); + + (await response.Content.ReadAsStringAsync()).Should().Be("Test response"); + response.Headers.Should().NotContainKey("X-Frame-Options"); + response.Headers.Should().ContainKey("Custom-Header").WhoseValue.Should().ContainSingle("MyValue"); + } + } + [Fact] public async Task HttpRequest_WithDefaultSecurityHeaders_WithConfiguredDefaultPolicy_SetsCustomHeaders() { diff --git a/test/SecurityHeadersMiddlewareWebSite/Startup.cs b/test/SecurityHeadersMiddlewareWebSite/Startup.cs index 6312e21..ea12fc7 100644 --- a/test/SecurityHeadersMiddlewareWebSite/Startup.cs +++ b/test/SecurityHeadersMiddlewareWebSite/Startup.cs @@ -23,7 +23,7 @@ public void Configure(IApplicationBuilder app) { app.UseSecurityHeaders(); app.UseRouting(); - app.UseSecurityHeaders(); + app.UseEndpointSecurityHeaders(); app.UseEndpoints(endpoints => { endpoints.MapGet("/custom", context => context.Response.WriteAsync("Hello World!"))