Skip to content
This repository has been archived by the owner on Jul 31, 2024. It is now read-only.

CORS rework to mix with hosting cors provider (if present) #245

Merged
merged 5 commits into from
Aug 29, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/IdentityServer4/Configuration/DependencyInjection/Decorator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.

using System;

namespace IdentityServer4.Configuration.DependencyInjection
{
public class Decorator<TService>
{
public TService Instance { get; set; }

public Decorator(TService instance)
{
Instance = instance;
}
}

public class Decorator<TService, TImpl> : Decorator<TService>
where TImpl : class, TService
{
public Decorator(TImpl instance) : base(instance)
{
}
}

public class DisposableDecorator<TService> : Decorator<TService>, IDisposable
{
public DisposableDecorator(TService instance) : base(instance)
{
}

public void Dispose()
{
(Instance as IDisposable)?.Dispose();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using IdentityServer4;
using IdentityServer4.Configuration;
using IdentityServer4.Configuration.DependencyInjection;
using IdentityServer4.Endpoints;
using IdentityServer4.Endpoints.Results;
using IdentityServer4.Events;
Expand All @@ -22,6 +23,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using System;
using System.Linq;

namespace Microsoft.Extensions.DependencyInjection
{
Expand Down Expand Up @@ -125,13 +127,7 @@ public static IIdentityServerBuilder AddCoreServices(this IIdentityServerBuilder
builder.Services.AddScoped<AuthenticationHandler>();

builder.Services.AddCors();
builder.Services.AddTransient<ICorsPolicyProvider>(provider =>
{
return new PolicyProvider(
provider.GetRequiredService<ILogger<PolicyProvider>>(),
Constants.ProtocolRoutePaths.CorsPaths,
provider.GetRequiredService<ICorsPolicyService>());
});
builder.Services.AddTransientDecorator<ICorsPolicyProvider, PolicyProvider>();

return builder;
}
Expand Down Expand Up @@ -204,5 +200,50 @@ public static IIdentityServerBuilder AddInMemoryCaching(this IIdentityServerBuil

return builder;
}

static void AddTransientDecorator<TService, TImplementation>(this IServiceCollection services)
where TService : class
where TImplementation : class, TService
{
services.AddDecorator<TService>();
services.AddTransient<TService, TImplementation>();
}

static void AddDecorator<TService>(this IServiceCollection services)
{
var registration = services.FirstOrDefault(x => x.ServiceType == typeof(TService));
if (registration == null)
{
throw new InvalidOperationException("Service type: " + typeof(TService).Name + " not registered.");
}
if (services.Any(x => x.ServiceType == typeof(Decorator<TService>)))
{
throw new InvalidOperationException("Decorator already registered for type: " + typeof(TService).Name + ".");
}

services.Remove(registration);

if (registration.ImplementationInstance != null)
{
var type = registration.ImplementationInstance.GetType();
var innerType = typeof(Decorator<,>).MakeGenericType(typeof(TService), type);
services.Add(new ServiceDescriptor(typeof(Decorator<TService>), innerType, ServiceLifetime.Transient));
services.Add(new ServiceDescriptor(type, registration.ImplementationInstance));
}
else if (registration.ImplementationFactory != null)
{
services.Add(new ServiceDescriptor(typeof(Decorator<TService>), provider =>
{
return new DisposableDecorator<TService>((TService)registration.ImplementationFactory(provider));
}, registration.Lifetime));
}
else
{
var type = registration.ImplementationType;
var innerType = typeof(Decorator<,>).MakeGenericType(typeof(TService), registration.ImplementationType);
services.Add(new ServiceDescriptor(typeof(Decorator<TService>), innerType, ServiceLifetime.Transient));
services.Add(new ServiceDescriptor(type, type, registration.Lifetime));
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.

using IdentityServer4.Extensions;
using Microsoft.AspNetCore.Http;
using System.Collections.Generic;
using System.Linq;

namespace IdentityServer4.Configuration
{
public class CorsOptions
{
/// <summary>
/// Gets or sets the name of the cors policy.
/// </summary>
/// <value>
/// The name of the cors policy.
/// </value>
public string CorsPolicyName { get; set; } = Constants.IdentityServerName;

/// <summary>
/// Gets or sets the cors paths.
/// </summary>
/// <value>
/// The cors paths.
/// </value>
public ICollection<PathString> CorsPaths { get; set; } = Constants.ProtocolRoutePaths.CorsPaths.Select(x => new PathString(x.EnsureLeadingSlash())).ToList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,13 @@ public class IdentityServerOptions
/// The caching options.
/// </value>
public CachingOptions CachingOptions { get; set; } = new CachingOptions();

/// <summary>
/// Gets or sets the cors options.
/// </summary>
/// <value>
/// The cors options.
/// </value>
public CorsOptions CorsOptions { get; set; } = new CorsOptions();
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.

using IdentityServer4.Configuration;
using IdentityServer4.Hosting;
using Microsoft.Extensions.DependencyInjection;
using System;

namespace Microsoft.AspNetCore.Builder
Expand All @@ -10,7 +12,9 @@ public static class IdentityServerApplicationBuilderExtensions
{
public static IApplicationBuilder UseIdentityServer(this IApplicationBuilder app)
{
app.UseCors(String.Empty);
var options = app.ApplicationServices.GetRequiredService<IdentityServerOptions>();
app.UseCors(options.CorsOptions.CorsPolicyName);

app.ConfigureCookies();
app.UseMiddleware<AuthenticationMiddleware>();
app.UseMiddleware<BaseUrlMiddleware>();
Expand Down
4 changes: 1 addition & 3 deletions src/IdentityServer4/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,7 @@ public static class ProtocolRoutePaths
DiscoveryWebKeys,
Token,
UserInfo,
// TODO
//IdentityTokenValidation,
//Revocation
Revocation
};
}

Expand Down
57 changes: 24 additions & 33 deletions src/IdentityServer4/Hosting/Cors/PolicyProvider.cs
Original file line number Diff line number Diff line change
@@ -1,38 +1,50 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using IdentityServer4.Services;
using Microsoft.Extensions.Logging;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Cors.Infrastructure;
using IdentityServer4.Configuration;
using IdentityServer4.Configuration.DependencyInjection;

namespace IdentityServer4.Hosting.Cors
{
public class PolicyProvider : ICorsPolicyProvider
{
private readonly ICorsPolicyService _corsPolicyService;
private readonly string[] _allowedPaths;
private readonly ILogger<PolicyProvider> _logger;
private readonly ICorsPolicyProvider _inner;
private readonly IdentityServerOptions _options;

public PolicyProvider(
ILogger<PolicyProvider> logger,
IEnumerable<string> allowedPaths,
Decorator<ICorsPolicyProvider> inner,
IdentityServerOptions options,
ICorsPolicyService corsPolicyService)
{
if (allowedPaths == null) throw new ArgumentNullException("allowedPaths");

_logger = logger;
_allowedPaths = allowedPaths.Select(Normalize).ToArray();
_inner = inner.Instance;
_options = options;
_corsPolicyService = corsPolicyService;
}

public async Task<CorsPolicy> GetPolicyAsync(HttpContext context, string policyName)
public Task<CorsPolicy> GetPolicyAsync(HttpContext context, string policyName)
{
if (_options.CorsOptions.CorsPolicyName == policyName)
{
return ProcessAsync(context);
}
else
{
return _inner.GetPolicyAsync(context, policyName);
}
}

async Task<CorsPolicy> ProcessAsync(HttpContext context)
{
var path = context.Request.Path.ToString();
var origin = context.Request.Headers["Origin"].First();
var thisOrigin = context.Request.Scheme + "://" + context.Request.Host;

Expand All @@ -42,6 +54,7 @@ public async Task<CorsPolicy> GetPolicyAsync(HttpContext context, string policyN
// todo: do we still need this check?
if (origin != null && origin != thisOrigin)
{
var path = context.Request.Path;
if (IsPathAllowed(path))
{
_logger.LogInformation("CORS request made for path: {0} from origin: {1}", path, origin);
Expand Down Expand Up @@ -78,31 +91,9 @@ private CorsPolicy Allow(string origin)
return policy;
}

private bool IsPathAllowed(string pathToCheck)
{
var requestPath = Normalize(pathToCheck);
return _allowedPaths.Any(path => requestPath.Equals(path, StringComparison.OrdinalIgnoreCase));
}

private string Normalize(string path)
private bool IsPathAllowed(PathString path)
{
if (String.IsNullOrWhiteSpace(path) || path == "/")
{
path = "/";
}
else
{
if (!path.StartsWith("/"))
{
path = "/" + path;
}
if (path.EndsWith("/"))
{
path = path.Substring(0, path.Length - 1);
}
}

return path;
return _options.CorsOptions.CorsPaths.Any(x => path == x);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.

using Microsoft.AspNetCore.Cors.Infrastructure;
using Microsoft.AspNetCore.Http;
using System.Threading.Tasks;

namespace IdentityServer4.UnitTests.Hosting.Cors
{
public class MockCorsPolicyProvider : ICorsPolicyProvider
{
public bool WasCalled { get; set; }
public CorsPolicy Response { get; set; }

public Task<CorsPolicy> GetPolicyAsync(HttpContext context, string policyName)
{
WasCalled = true;
return Task.FromResult(Response);
}
}
}
Loading