diff --git a/src/adapter.ts b/src/adapter.ts index 46376fb6..03d54d36 100644 --- a/src/adapter.ts +++ b/src/adapter.ts @@ -15,7 +15,7 @@ import { Adapter, Helper, Model } from 'casbin'; import { Op } from 'sequelize'; import { Sequelize, SequelizeOptions } from 'sequelize-typescript'; -import { CasbinRule, updateCasbinRule } from './casbinRule'; +import { createCasbinRule, CasbinRule } from './casbinRule'; export interface SequelizeAdapterOptions extends SequelizeOptions { tableName?: string; @@ -30,6 +30,7 @@ export class SequelizeAdapter implements Adapter { private sequelize: Sequelize; private filtered = false; private autoCreateTable = true; + private CasbinRule: typeof CasbinRule; constructor(option: SequelizeAdapterOptions, autoCreateTable = true) { this.option = option; @@ -60,9 +61,12 @@ export class SequelizeAdapter implements Adapter { private async open(): Promise { this.sequelize = new Sequelize(this.option); - updateCasbinRule(this.option.tableName, this.option.schema); + this.CasbinRule = createCasbinRule( + this.option.tableName, + this.option.schema + ); // Set the property here await this.sequelize.authenticate(); - this.sequelize.addModels([CasbinRule]); + this.sequelize.addModels([this.CasbinRule]); if (this.autoCreateTable) { await this.createTable(); } @@ -90,7 +94,7 @@ export class SequelizeAdapter implements Adapter { * loadPolicy loads all policy rules from the storage. */ public async loadPolicy(model: Model): Promise { - const lines = await this.sequelize.getRepository(CasbinRule).findAll(); + const lines = await this.sequelize.getRepository(this.CasbinRule).findAll(); for (const line of lines) { this.loadPolicyLine(line, model); @@ -98,7 +102,7 @@ export class SequelizeAdapter implements Adapter { } private savePolicyLine(ptype: string, rule: string[]): CasbinRule { - const line = new CasbinRule(); + const line = new this.CasbinRule(); line.ptype = ptype; if (rule.length > 0) { @@ -130,7 +134,7 @@ export class SequelizeAdapter implements Adapter { await this.sequelize.transaction(async (tx) => { // truncate casbin table await this.sequelize - .getRepository(CasbinRule) + .getRepository(this.CasbinRule) .destroy({ where: {}, truncate: true, transaction: tx }); const lines: CasbinRule[] = []; @@ -151,7 +155,7 @@ export class SequelizeAdapter implements Adapter { } } - await CasbinRule.bulkCreate( + await this.CasbinRule.bulkCreate( lines.map((l) => l.get({ plain: true })), { transaction: tx } ); @@ -185,7 +189,7 @@ export class SequelizeAdapter implements Adapter { lines.push(line); } await this.sequelize.transaction(async (tx) => { - await CasbinRule.bulkCreate( + await this.CasbinRule.bulkCreate( lines.map((l) => l.get({ plain: true })), { transaction: tx } ); @@ -210,7 +214,7 @@ export class SequelizeAdapter implements Adapter { where[key] = line[key]; }); - await this.sequelize.getRepository(CasbinRule).destroy({ where }); + await this.sequelize.getRepository(this.CasbinRule).destroy({ where }); } /** @@ -234,7 +238,7 @@ export class SequelizeAdapter implements Adapter { }); await this.sequelize - .getRepository(CasbinRule) + .getRepository(this.CasbinRule) .destroy({ where, transaction: tx }); } }); @@ -271,7 +275,7 @@ export class SequelizeAdapter implements Adapter { }; const lines = await this.sequelize - .getRepository(CasbinRule) + .getRepository(this.CasbinRule) .findAll({ where }); lines.forEach((line) => this.loadPolicyLine(line, model)); @@ -287,7 +291,7 @@ export class SequelizeAdapter implements Adapter { fieldIndex: number, ...fieldValues: string[] ): Promise { - const line = new CasbinRule(); + const line = new this.CasbinRule(); line.ptype = ptype; const idx = fieldIndex + fieldValues.length; @@ -319,7 +323,7 @@ export class SequelizeAdapter implements Adapter { where[key] = line[key]; }); - await this.sequelize.getRepository(CasbinRule).destroy({ + await this.sequelize.getRepository(this.CasbinRule).destroy({ where, }); } diff --git a/src/casbinRule.ts b/src/casbinRule.ts index ef49574c..b4180612 100644 --- a/src/casbinRule.ts +++ b/src/casbinRule.ts @@ -44,12 +44,16 @@ export class CasbinRule extends Model { public v5: string; } -export function updateCasbinRule( +export function createCasbinRule( tableName = 'casbin_rule', schema?: string -): void { - const options = getOptions(CasbinRule.prototype); +): typeof CasbinRule { + class CustomCasbinRule extends CasbinRule {} + + const options = getOptions(CustomCasbinRule.prototype); options!.tableName = tableName; options!.schema = schema; - setOptions(CasbinRule.prototype, options!); + setOptions(CustomCasbinRule.prototype, options!); + + return CustomCasbinRule; }