diff --git a/README.md b/README.md index c79296a..3b57950 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ def hello(request): A string containing the file location of your casbin model. ### `CASBIN_ADAPTER` -A string containing the adapter import path. Defaults to the django adapter shipped with this package: `casbin_adapter.adapter.Adapter` +A string containing the adapter import path. Default to the django adapter shipped with this package: `casbin_adapter.adapter.Adapter` ### `CASBIN_ADAPTER_ARGS` A tuple of arguments to be passed into the constructor of the adapter specified @@ -80,6 +80,9 @@ E.g. if you wish to use the file adapter set the adapter to `casbin.persist.adapters.FileAdapter` and use `CASBIN_ADAPTER_ARGS = ('path/to/policy_file.csv',)` +### `CASBIN_DB_ALIAS` +The database the adapter uses. Default to "default". + ### `CASBIN_WATCHER` Watcher instance to be set as the watcher on the enforcer instance. diff --git a/casbin_adapter/adapter.py b/casbin_adapter/adapter.py index 44c5461..0984d41 100644 --- a/casbin_adapter/adapter.py +++ b/casbin_adapter/adapter.py @@ -11,10 +11,13 @@ class Adapter(persist.Adapter): """the interface for Casbin adapters.""" + def __init__(self, db_alias="default"): + self.db_alias = db_alias + def load_policy(self, model): """loads all policy rules from the storage.""" try: - lines = CasbinRule.objects.all() + lines = CasbinRule.objects.using(self.db_alias).all() for line in lines: persist.load_policy_line(str(line), model) @@ -41,7 +44,7 @@ def save_policy(self, model): """saves all policy rules to the storage.""" # See https://casbin.org/docs/en/adapters#autosave # for why this is deleting all rules - CasbinRule.objects.all().delete() + CasbinRule.objects.using(self.db_alias).all().delete() lines = [] for sec in ["p", "g"]: @@ -50,7 +53,7 @@ def save_policy(self, model): for ptype, ast in model.model[sec].items(): for rule in ast.policy: lines.append(self._create_policy_line(ptype, rule)) - CasbinRule.objects.bulk_create(lines) + CasbinRule.objects.using(self.db_alias).bulk_create(lines) return True def add_policy(self, sec, ptype, rule): @@ -63,7 +66,7 @@ def remove_policy(self, sec, ptype, rule): query_params = {"ptype": ptype} for i, v in enumerate(rule): query_params["v{}".format(i)] = v - rows_deleted, _ = CasbinRule.objects.filter(**query_params).delete() + rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete() return True if rows_deleted > 0 else False def remove_filtered_policy(self, sec, ptype, field_index, *field_values): @@ -77,5 +80,5 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values): return False for i, v in enumerate(field_values): query_params["v{}".format(i + field_index)] = v - rows_deleted, _ = CasbinRule.objects.filter(**query_params).delete() + rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete() return True if rows_deleted > 0 else False diff --git a/casbin_adapter/apps.py b/casbin_adapter/apps.py index c92ab63..75273ce 100644 --- a/casbin_adapter/apps.py +++ b/casbin_adapter/apps.py @@ -1,6 +1,4 @@ from django.apps import AppConfig -from django.db import connection -from django.db.utils import OperationalError, ProgrammingError class CasbinAdapterConfig(AppConfig): diff --git a/casbin_adapter/enforcer.py b/casbin_adapter/enforcer.py index 0488bf7..d136408 100644 --- a/casbin_adapter/enforcer.py +++ b/casbin_adapter/enforcer.py @@ -5,7 +5,6 @@ from casbin import Enforcer -from .adapter import Adapter from .utils import import_class logger = logging.getLogger(__name__) @@ -13,6 +12,7 @@ class ProxyEnforcer(Enforcer): _initialized = False + db_alias = "default" def __init__(self, *args, **kwargs): if self._initialized: @@ -27,8 +27,9 @@ def _load(self): model = getattr(settings, "CASBIN_MODEL") adapter_loc = getattr(settings, "CASBIN_ADAPTER", "casbin_adapter.adapter.Adapter") adapter_args = getattr(settings, "CASBIN_ADAPTER_ARGS", tuple()) + self.db_alias = getattr(settings, "CASBIN_DB_ALIAS", "default") Adapter = import_class(adapter_loc) - adapter = Adapter(*adapter_args) + adapter = Adapter(self.db_alias, *adapter_args) super().__init__(model, adapter) logger.debug("Casbin enforcer initialised") @@ -44,7 +45,7 @@ def _load(self): def __getattribute__(self, name): safe_methods = ["__init__", "_load", "_initialized"] if not super().__getattribute__("_initialized") and name not in safe_methods: - initialize_enforcer() + initialize_enforcer(self.db_alias) if not super().__getattribute__("_initialized"): raise Exception( ( @@ -59,17 +60,29 @@ def __getattribute__(self, name): enforcer = ProxyEnforcer() -def initialize_enforcer(): +def initialize_enforcer(db_alias=None): try: - with connection.cursor() as cursor: - cursor.execute( - """ - SELECT app, name applied FROM django_migrations - WHERE app = 'casbin_adapter' AND name = '0001_initial'; - """ - ) - row = cursor.fetchone() - if row: - enforcer._load() + row = None + if db_alias: + with connection[db_alias].cursor() as cursor: + cursor.execute( + """ + SELECT app, name applied FROM django_migrations + WHERE app = 'casbin_adapter' AND name = '0001_initial'; + """ + ) + row = cursor.fetchone() + else: + with connection.cursor() as cursor: + cursor.execute( + """ + SELECT app, name applied FROM django_migrations + WHERE app = 'casbin_adapter' AND name = '0001_initial'; + """ + ) + row = cursor.fetchone() + + if row: + enforcer._load() except (OperationalError, ProgrammingError): pass diff --git a/requirements.txt b/requirements.txt index 7b56393..9d2bde5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -casbin==1.16.10 +casbin>=1.16.10 Django diff --git a/requirements_dev.txt b/requirements_dev.txt index b2dba3b..65ac54c 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,3 +1,2 @@ -r requirements.txt -setuptools==60.2.0 -simpleeval==0.9.12 \ No newline at end of file +setuptools==60.2.0 \ No newline at end of file diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 2da9795..7407a41 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,7 +1,6 @@ import os import casbin import simpleeval -from unittest import TestCase from django.test import TestCase from casbin_adapter.models import CasbinRule