Skip to content

Commit

Permalink
Endpoint security middleware (#173)
Browse files Browse the repository at this point in the history
* Update FluentAssertions

* Split out EndpointSecurityHeadersMiddleware from SecurityHeadersMiddleware
  • Loading branch information
andrewlock committed Sep 25, 2024
1 parent 7949a01 commit d1b0a6d
Show file tree
Hide file tree
Showing 10 changed files with 250 additions and 87 deletions.
15 changes: 6 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ In some situations, you may need to apply different security headers to differen

### 1. Configure your policies using `AddSecurityHeaderPolicies()`

You can configure named and default policies by calling `AddSecurityHeaderPolicies()` on `IServiceCollection`. You can configure the default policy to use, as well as any named policies. For example, the following configures the default policy (used when `UseSecurityHeaders()` is called without any arguments), and a named policy:
You can configure named and default policies by calling `AddSecurityHeaderPolicies()` on `IServiceCollection`. You can configure the default policy to use, as well as any named policies. For example, the following configures the default policy (used for all requests that are not customised for an endpoint), and a named policy:

```csharp
var builder = WebApplication.CreateBuilder();
Expand All @@ -172,9 +172,9 @@ builder.Services.AddSecurityHeaderPolicies()
```

### 2. Add the default middleware early to the pipeline
### 2. Call `UseSecurityHeaders()` early in the middleware pipeline

The security headers middleware can only add headers to _all_ requests if it is early in the middleware pipeline, so it's important to add the headders middleware at the start of your middleware pipeline. However, if you want to have endpoint-specific policies, then you _also_ need to place the middleware after the call to `UseRouting()`, as that is the point at which the endpoint that will be executed is selected.
The security headers middleware can only add headers to _all_ requests if it is early in the middleware pipeline, so it's important to add the headers middleware at the start of your middleware pipeline by calling `UseSecurityHeaders()`. However, if you want to have endpoint-specific policies, then you also need to call `UseEndpointSecurityHeaders()` _after_ the call to `UseRouting()`. For example:

```csharp
var builder = WebApplication.CreateBuilder();
Expand All @@ -194,15 +194,13 @@ app.UseStaticFiles(); // other middleware
app.UseAuthentication();
app.UseRouting();

app.UseSecurityHeaders(); // 👈 Add after the routing middleware
app.UseEndpointSecurityHeaders(); // 👈 Add after the routing middleware
app.UseAuthorization();

app.MapGet("/", () => "Hello world");
app.Run();
```

Note that if you pass a policy to any call to `UseSecurityHeaders()` it will override the "default" policy used at that point.

### 3. Apply custom policies to endpoints

To apply a non-default policy to an endpoint, use the `WithSecurityHeadersPolicy(policy)` endpoint extension method, and pass in the name of the policy to apply:
Expand Down Expand Up @@ -243,14 +241,13 @@ public class HomeController : ControllerBase
}
```

Each call to `UseSecurityHeaders()` will re-evaluate the applicable policies; the headers are applied just before the response is sent. The policy to apply is determined as follows, with the first applicable policy selected.
Security headers are applied just before the response is sent. If you use the configuration described above, then the policy to apply is determined as follows, with the first applicable policy selected:

1. If an endpoint has been selected, and a named policy is applied, use that.
2. If a named or policy instance is passed to the `SecurityHeadersMiddleware`, use that.
2. If a named or policy instance is passed to the `SecurityHeadersMiddleware()`, use that.
3. If the default policy has been set using `SetDefaultPolicy()`, use that.
4. Otherwise, apply the default headers (those added by `AddDefaultSecurityHeaders()`)


## RemoveServerHeader

One point to be aware of is that the `RemoveServerHeader` method will rarely (ever?) be sufficient to remove the `Server` header from your output. If any subsequent middleware in your application pipeline add the header, then this will be able to remove it. However Kestrel will generally add the `Server` header too late in the pipeline to be able to modify it.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using System;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using NetEscapades.AspNetCore.SecurityHeaders.Headers;
using NetEscapades.AspNetCore.SecurityHeaders.Infrastructure;

namespace NetEscapades.AspNetCore.SecurityHeaders;

/// <summary>
/// An ASP.NET Core middleware for adding security headers.
/// </summary>
internal class EndpointSecurityHeadersMiddleware
{
private readonly RequestDelegate _next;
private readonly ILogger<EndpointSecurityHeadersMiddleware> _logger;
private readonly CustomHeaderOptions _options;

/// <summary>
/// Initializes a new instance of the <see cref="EndpointSecurityHeadersMiddleware"/> class.
/// </summary>
/// <param name="next">The next middleware in the pipeline.</param>
/// <param name="logger">A logger for recording errors.</param>
/// <param name="options">Options on how to control the settings that are applied</param>
public EndpointSecurityHeadersMiddleware(RequestDelegate next, ILogger<EndpointSecurityHeadersMiddleware> logger, CustomHeaderOptions options)
{
_next = next ?? throw new ArgumentNullException(nameof(next));
_logger = logger;
_options = options;
}

/// <summary>
/// Invoke the middleware
/// </summary>
/// <param name="context">The current context</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
public Task Invoke(HttpContext context)
{
// Policy resolution rules:
//
// 1. If there is an endpoint with a named policy, then fetch that policy
// 2. Use the provided default policy
var endpoint = context.GetEndpoint();
var metadata = endpoint?.Metadata.GetMetadata<ISecurityHeadersPolicyMetadata>();

if (!string.IsNullOrEmpty(metadata?.PolicyName))
{
if (_options.GetPolicy(metadata.PolicyName) is { } namedPolicy)
{
context.Items[SecurityHeadersMiddleware.HttpContextKey] = namedPolicy;
}
else
{
// log that we couldn't find the policy
_logger.LogWarning(
"Error configuring security headers middleware: policy '{PolicyName}' could not be found. "
+ "Configure the policies for your application by calling AddSecurityHeaderPolicies() on IServiceCollection "
+ "and adding a policy with the required name.",
metadata.PolicyName);
}
}

return _next(context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,23 @@ namespace NetEscapades.AspNetCore.SecurityHeaders;
/// </summary>
internal class SecurityHeadersMiddleware
{
private const string HttpContextKey = "__NetEscapades.AspNetCore.SecurityHeaders";
/// <summary>
/// The HttpContext key that tracks the policy to apply
/// </summary>
internal const string HttpContextKey = "__NetEscapades.AspNetCore.SecurityHeaders";

private readonly RequestDelegate _next;
private readonly ILogger<SecurityHeadersMiddleware> _logger;
private readonly CustomHeaderOptions? _options;
private readonly HeaderPolicyCollection _defaultPolicy;
private readonly NonceGenerator? _nonceGenerator;

/// <summary>
/// Initializes a new instance of the <see cref="SecurityHeadersMiddleware"/> class.
/// </summary>
/// <param name="next">The next middleware in the pipeline.</param>
/// <param name="logger">A logger for recording errors.</param>
/// <param name="options">Options on how to control the settings that are applied</param>
/// <param name="defaultPolicy">A <see cref="HeaderPolicyCollection"/> containing the policies to be applied.</param>
public SecurityHeadersMiddleware(RequestDelegate next, ILogger<SecurityHeadersMiddleware> logger, CustomHeaderOptions? options, HeaderPolicyCollection defaultPolicy)
/// <param name="defaultPolicy">A <see cref="HeaderPolicyCollection"/> containing the policy to apply by default.</param>
public SecurityHeadersMiddleware(RequestDelegate next, HeaderPolicyCollection defaultPolicy)
{
_next = next ?? throw new ArgumentNullException(nameof(next));
_logger = logger;
_options = options;
_defaultPolicy = defaultPolicy ?? throw new ArgumentNullException(nameof(defaultPolicy));
_nonceGenerator = MustGenerateNonce(_defaultPolicy) ? new() : null;
}
Expand All @@ -42,47 +40,17 @@ public SecurityHeadersMiddleware(RequestDelegate next, ILogger<SecurityHeadersMi
/// </summary>
/// <param name="context">The current context</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
public async Task Invoke(HttpContext context)
public Task Invoke(HttpContext context)
{
// Policy resolution rules:
//
// 1. If there is an endpoint with a named policy, then fetch that policy
// 2. Use the provided default policy
var endpoint = context.GetEndpoint();
var metadata = endpoint?.Metadata.GetMetadata<ISecurityHeadersPolicyMetadata>();

HeaderPolicyCollection policy = _defaultPolicy;

if (!string.IsNullOrEmpty(metadata?.PolicyName))
{
if (_options?.GetPolicy(metadata.PolicyName) is { } namedPolicy)
{
policy = namedPolicy;
}
else
{
// log that we couldn't find the policy
_logger.LogWarning(
"Error configuring security headers middleware: policy '{PolicyName}' could not be found. "
+ "Configure the policies for your application by calling AddSecurityHeaderPolicies() on IServiceCollection "
+ "and adding a policy with the required name. Using default policy for request",
metadata.PolicyName);
}
}

if (context.Items[HttpContextKey] is null)
// Write into the context, so that subsequent requests can "overwrite" it
context.Items[HttpContextKey] = _defaultPolicy;
context.Response.OnStarting(OnResponseStarting, context);
if (_nonceGenerator is not null)
{
context.Response.OnStarting(OnResponseStarting, context);
if (_nonceGenerator is not null)
{
context.SetNonce(_nonceGenerator.GetNonce(Constants.DefaultBytesInNonce));
}
context.SetNonce(_nonceGenerator.GetNonce(Constants.DefaultBytesInNonce));
}

// Write into the context, so that subsequent requests can "overwrite" it
context.Items[HttpContextKey] = policy;

await _next(context);
return _next(context);
}

private static Task OnResponseStarting(object state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ public static IApplicationBuilder UseSecurityHeaders(this IApplicationBuilder ap
throw new ArgumentNullException(nameof(policies));
}

var options = (CustomHeaderOptions)app.ApplicationServices.GetService(typeof(CustomHeaderOptions));
return app.UseSecurityHeaders(options, policies);
return app.UseMiddleware<SecurityHeadersMiddleware>(policies);
}

/// <summary>
Expand Down Expand Up @@ -70,7 +69,7 @@ public static IApplicationBuilder UseSecurityHeaders(this IApplicationBuilder ap
var options = app.ApplicationServices.GetService(typeof(CustomHeaderOptions)) as CustomHeaderOptions;
var policy = options?.DefaultPolicy ?? new HeaderPolicyCollection().AddDefaultSecurityHeaders();

return app.UseSecurityHeaders(options, policy);
return app.UseSecurityHeaders(policy);
}

/// <summary>
Expand All @@ -95,20 +94,39 @@ public static IApplicationBuilder UseSecurityHeaders(this IApplicationBuilder ap
var policy = options?.GetPolicy(policyName);
if (policy is null)
{
var log = ((ILoggerFactory)app.ApplicationServices.GetRequiredService(typeof(ILoggerFactory))).CreateLogger(typeof(SecurityHeadersMiddlewareExtensions));
log.LogWarning(
"Error configuring security headers middleware: policy '{PolicyName}' could not be found. "
+ "Configure the policies for your application by calling AddSecurityHeaderPolicies() on IServiceCollection "
+ "and adding a policy with the required name.",
policyName);
return app;
throw new InvalidOperationException(
$"Error configuring security headers middleware: policy '{policyName}' could not be found. "
+ "Configure the policies for your application by calling IServiceCollection.AddSecurityHeaderPolicies() "
+ $"in your application startup code and add a namedpolicy called '{policyName}'");
}

return app.UseSecurityHeaders(options, policy);
return app.UseSecurityHeaders(policy);
}

private static IApplicationBuilder UseSecurityHeaders(this IApplicationBuilder app, CustomHeaderOptions? options, HeaderPolicyCollection policies)
/// <summary>
/// Adds middleware to your web application pipeline which sets security headers on responses
/// based on the specific endpoint invoked.
///
/// To apply policies to a specific endpoint, use <see cref="EndpointConventionBuilderExtensions.WithSecurityHeadersPolicy{TBuilder}"/>
/// or apply <see cref="SecurityHeadersPolicyAttribute"/> to your MVC or Razor Page endpoints.
/// </summary>
/// <param name="app">The IApplicationBuilder passed to your Configure method.</param>
/// <returns>The original app parameter</returns>
public static IApplicationBuilder UseEndpointSecurityHeaders(this IApplicationBuilder app)
{
return app.UseMiddleware<SecurityHeadersMiddleware>(options ?? new CustomHeaderOptions(), policies);
if (app == null)
{
throw new ArgumentNullException(nameof(app));
}

if (app.ApplicationServices.GetService(typeof(CustomHeaderOptions)) is not CustomHeaderOptions options)
{
throw new InvalidOperationException(
"Error configuring security headers middleware: Unable to find required services. "
+ "Configure the policies for your application by calling IServiceCollection.AddSecurityHeaderPolicies() "
+ "in your application startup code");
}

return app.UseMiddleware<EndpointSecurityHeadersMiddleware>(options);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,9 @@ public void Build_AddSrciptSrc_WhenAddsNonce_ConstantValueThrowsInvalidOperation

var result = builder.Build();

result.Invoking(x =>
{
var val = x.ConstantValue;
})
.ShouldThrow<InvalidOperationException>();
result.Invoking(x=>x.ConstantValue)
.Should()
.Throw<InvalidOperationException>();
}
[Theory]
[InlineData(true, false)]
Expand Down Expand Up @@ -272,11 +270,9 @@ public void Build_AddSrciptSrc_WhenDoesntAddNonce_BuilderThrowsInvalidOperation(

var result = builder.Build();

result.Invoking(x =>
{
var val = x.Builder;
})
.ShouldThrow<InvalidOperationException>();
result.Invoking(x=>x.Builder)
.Should()
.Throw<InvalidOperationException>();
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public async Task WhenUsingEndpoint_Overrides_Default(string path, string expect
content.Should().Be(expected);

// no security headers
response.Headers.Should().NotContain("X-Frame-Options");
response.Headers.Should().NotContainKey("X-Frame-Options");
response.Headers.TryGetValues("Custom-Header", out var customHeader).Should().BeTrue();
customHeader.Should().ContainSingle("MyValue");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public async Task WhenUsingEndpoint_Overrides_Default(string path, string expect
content.Should().Be(expected);

// no security headers
response.Headers.Should().NotContain("X-Frame-Options");
response.Headers.Should().NotContainKey("X-Frame-Options");
response.Headers.TryGetValues("Custom-Header", out var customHeader).Should().BeTrue();
customHeader.Should().ContainSingle("MyValue");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<ItemGroup>
<PackageReference Include="PublicApiGenerator" Version="11.1.0" />
<PackageReference Include="Verify.Xunit" Version="18.4.0" />
<PackageReference Include="FluentAssertions" Version="4.19.4" />
<PackageReference Include="FluentAssertions" Version="6.12.0" />
<PackageReference Include="xunit" Version="2.4.2" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.5" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
Expand Down
Loading

0 comments on commit d1b0a6d

Please sign in to comment.