Skip to content

Commit

Permalink
fix: cors
Browse files Browse the repository at this point in the history
  • Loading branch information
izatop committed Aug 15, 2020
1 parent 70a60aa commit 5ac872b
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 14 deletions.
25 changes: 25 additions & 0 deletions packages/test/src/test1.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import {Application, PathRoute} from "@typesafeunit/app";
import {Runtime} from "@typesafeunit/unit";
import {debugLogFormat, Logger, StdErrorTransport, StdOutTransport} from "@typesafeunit/util";
import {CorsValidation, Server} from "@typesafeunit/web";
import {BaseTestAction} from "./actions/BaseTestAction";
import {BaseContext} from "./context/BaseContext";

async function main() {
const app = await Application.factory(new BaseContext(), [
new PathRoute(BaseTestAction, {route: "GET /test", state: () => ({name: "test"})}),
]);

const validators = CorsValidation.factory(app);
const server = new Server(app, {validators});
return server.listen(10000);
}

Runtime.initialize(() => {
Logger.set([
new StdErrorTransport(debugLogFormat),
new StdOutTransport(debugLogFormat),
]);
});

Runtime.run(main);
13 changes: 9 additions & 4 deletions packages/util/src/assert.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
import {AssertionError} from "./Exception";
import {isFunction} from "./is";
import {isFunction, isInstanceOf} from "./is";

export type AssertionDetailsAllowType = string | Record<any, any> | null | number;
export type AssertionDetails = (() => AssertionDetailsAllowType) | AssertionDetailsAllowType;
export type AssertionMessage = string | (() => string);
export type AssertionMessage = string | (() => string) | (() => Error);

function createAssertionError(message?: AssertionMessage, details?: AssertionDetails) {
const description = isFunction(message) ? message() : message;
return isInstanceOf(description, Error) ? description : new AssertionError(description, details);
}

export function assert(expr: unknown, message?: AssertionMessage, details?: AssertionDetails): asserts expr {
if (!expr) {
throw new AssertionError(isFunction(message) ? message() : message, details);
throw createAssertionError(message, details);
}
}

export function fails(expr: unknown, message?: AssertionMessage, details?: AssertionDetails): void {
if (!!expr) {
throw new AssertionError(isFunction(message) ? message() : message, details);
throw createAssertionError(message, details);
}
}

Expand Down
4 changes: 4 additions & 0 deletions packages/web/src/Transport/Request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ export class Request extends RequestAbstract {
}
}

public get origin(): string {
return this.headers.get("origin", "");
}

public validate(): boolean {
this.#validators.forEach((validator) => validator.validate(this));
return true;
Expand Down
56 changes: 47 additions & 9 deletions packages/web/src/Transport/Validation/CorsValidation.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import {Application, IRouteMatcher, PathRoute, RequestValidatorAbstract, RouteAbstract} from "@typesafeunit/app";
import {isFunction, isInstanceOf, isString} from "@typesafeunit/util";
import {
Application,
IRouteMatcher,
PathRoute,
RequestValidatorAbstract,
RouteAbstract,
RouteNotFound,
} from "@typesafeunit/app";
import {assert, isDefined, isFunction, isInstanceOf, isString} from "@typesafeunit/util";
import {Headers} from "../Headers";
import {ICorsOptions} from "../interfaces";
import {Request} from "../Request";
Expand All @@ -12,8 +19,8 @@ export class CorsValidation extends RequestValidatorAbstract<ICorsOptions> {
return this.options.origin;
}

public factory(app: Application<any, any>, options: ICorsOptions): CorsValidation {
const validator = new CorsValidation(options);
public static factory(app: Application<any, any>, options: ICorsOptions = {}): CorsValidation {
const validator = new this(options);
validator.updateRoutes(app);
return validator;
}
Expand All @@ -26,19 +33,31 @@ export class CorsValidation extends RequestValidatorAbstract<ICorsOptions> {
return;
}

let found = false;
for (const [matcher, method] of this.#table) {
const route = request.route.replace("OPTIONS", method);
if (matcher.test(route)) {
AccessControlAllowMethods.add(method);
found = true;
}
}

assert(found, () => new RouteNotFound("Not Found"));
const methods = [...AccessControlAllowMethods.values()];
const headers = this.getAccessControlHeaders(request, methods);
throw new NoContentResponse({headers: new Headers(headers)});
}

request.setResponseHeaders([["Access-Control-Allow-Origin", this.getOrigin(request)]]);
const setHeaders: [string, string][] = [
["Access-Control-Allow-Origin", this.getAccessControlOrigin(request)],
];

const vary = this.getVary();
if (vary) {
setHeaders.push(["Vary", vary]);
}

request.setResponseHeaders(setHeaders);
}

protected getAccessControlHeaders(request: Request, methods: string[]): [string, string][] {
Expand All @@ -47,17 +66,36 @@ export class CorsValidation extends RequestValidatorAbstract<ICorsOptions> {
"Content-Type, Accept, Authorization",
);

return [
["Access-Control-Allow-Origin", this.getOrigin(request)],
const headers: [string, string][] = [
["Access-Control-Allow-Origin", this.getAccessControlOrigin(request)],
["Access-Control-Allow-Headers", acRequestHeaders],
["Access-Control-Allow-Methods", methods.join(", ")],
["Access-Control-Max-Age", "86400"],
];

const vary = this.getVary();
if (vary) {
headers.push(["Vary", vary]);
}

if (isDefined(this.options.credentials)) {
headers.push(["Access-Control-Allow-Credentials", this.options.credentials ? "true" : "false"]);
}

return headers;
}

protected getVary(): string | undefined {
if (isString(this.options.origin) && this.options.origin === "origin") {
return "Origin";
}

return;
}

protected getOrigin(request: Request): string {
protected getAccessControlOrigin(request: Request): string {
if (isString(this.origin)) {
return this.origin;
return this.origin === "origin" ? request.origin : this.origin;
}

if (isFunction(this.origin)) {
Expand Down
3 changes: 2 additions & 1 deletion packages/web/src/Transport/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ export type ServerHeadersResolver = (request: Request) => { [key: string]: strin
export type ServerRequestHandler<T = void> = (request: Request) => T;

export interface ICorsOptions {
origin: string | ServerRequestHandler<string>;
origin?: string | ServerRequestHandler<string> | "origin";
credentials?: boolean;
}

export interface IServerOptions {
Expand Down

0 comments on commit 5ac872b

Please sign in to comment.