Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add db alias #21

Merged
merged 3 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def hello(request):
A string containing the file location of your casbin model.

### `CASBIN_ADAPTER`
A string containing the adapter import path. Defaults to the django adapter shipped with this package: `casbin_adapter.adapter.Adapter`
A string containing the adapter import path. Default to the django adapter shipped with this package: `casbin_adapter.adapter.Adapter`

### `CASBIN_ADAPTER_ARGS`
A tuple of arguments to be passed into the constructor of the adapter specified
Expand All @@ -80,6 +80,9 @@ E.g. if you wish to use the file adapter
set the adapter to `casbin.persist.adapters.FileAdapter` and use
`CASBIN_ADAPTER_ARGS = ('path/to/policy_file.csv',)`

### `CASBIN_DB_ALIAS`
hsluoyz marked this conversation as resolved.
Show resolved Hide resolved
The database the adapter uses. Default to "default".

### `CASBIN_WATCHER`
Watcher instance to be set as the watcher on the enforcer instance.

Expand Down
13 changes: 8 additions & 5 deletions casbin_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
class Adapter(persist.Adapter):
"""the interface for Casbin adapters."""

def __init__(self, db_alias="default"):
self.db_alias = db_alias

def load_policy(self, model):
"""loads all policy rules from the storage."""
try:
lines = CasbinRule.objects.all()
lines = CasbinRule.objects.using(self.db_alias).all()
hsluoyz marked this conversation as resolved.
Show resolved Hide resolved

for line in lines:
persist.load_policy_line(str(line), model)
Expand All @@ -41,7 +44,7 @@ def save_policy(self, model):
"""saves all policy rules to the storage."""
# See https://casbin.org/docs/en/adapters#autosave
# for why this is deleting all rules
CasbinRule.objects.all().delete()
CasbinRule.objects.using(self.db_alias).all().delete()

lines = []
for sec in ["p", "g"]:
Expand All @@ -50,7 +53,7 @@ def save_policy(self, model):
for ptype, ast in model.model[sec].items():
for rule in ast.policy:
lines.append(self._create_policy_line(ptype, rule))
CasbinRule.objects.bulk_create(lines)
CasbinRule.objects.using(self.db_alias).bulk_create(lines)
return True

def add_policy(self, sec, ptype, rule):
Expand All @@ -63,7 +66,7 @@ def remove_policy(self, sec, ptype, rule):
query_params = {"ptype": ptype}
for i, v in enumerate(rule):
query_params["v{}".format(i)] = v
rows_deleted, _ = CasbinRule.objects.filter(**query_params).delete()
rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete()
return True if rows_deleted > 0 else False

def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
Expand All @@ -77,5 +80,5 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
return False
for i, v in enumerate(field_values):
query_params["v{}".format(i + field_index)] = v
rows_deleted, _ = CasbinRule.objects.filter(**query_params).delete()
rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete()
return True if rows_deleted > 0 else False
2 changes: 0 additions & 2 deletions casbin_adapter/apps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from django.apps import AppConfig
from django.db import connection
from django.db.utils import OperationalError, ProgrammingError


class CasbinAdapterConfig(AppConfig):
Expand Down
41 changes: 27 additions & 14 deletions casbin_adapter/enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from casbin import Enforcer

from .adapter import Adapter
from .utils import import_class

logger = logging.getLogger(__name__)


class ProxyEnforcer(Enforcer):
_initialized = False
db_alias = "default"

def __init__(self, *args, **kwargs):
if self._initialized:
Expand All @@ -27,8 +27,9 @@ def _load(self):
model = getattr(settings, "CASBIN_MODEL")
adapter_loc = getattr(settings, "CASBIN_ADAPTER", "casbin_adapter.adapter.Adapter")
adapter_args = getattr(settings, "CASBIN_ADAPTER_ARGS", tuple())
self.db_alias = getattr(settings, "CASBIN_DB_ALIAS", "default")
Adapter = import_class(adapter_loc)
adapter = Adapter(*adapter_args)
adapter = Adapter(self.db_alias, *adapter_args)

super().__init__(model, adapter)
logger.debug("Casbin enforcer initialised")
Expand All @@ -44,7 +45,7 @@ def _load(self):
def __getattribute__(self, name):
safe_methods = ["__init__", "_load", "_initialized"]
if not super().__getattribute__("_initialized") and name not in safe_methods:
initialize_enforcer()
initialize_enforcer(self.db_alias)
if not super().__getattribute__("_initialized"):
raise Exception(
(
Expand All @@ -59,17 +60,29 @@ def __getattribute__(self, name):
enforcer = ProxyEnforcer()


def initialize_enforcer():
def initialize_enforcer(db_alias=None):
try:
with connection.cursor() as cursor:
cursor.execute(
"""
SELECT app, name applied FROM django_migrations
WHERE app = 'casbin_adapter' AND name = '0001_initial';
"""
)
row = cursor.fetchone()
if row:
enforcer._load()
row = None
if db_alias:
with connection[db_alias].cursor() as cursor:
cursor.execute(
"""
SELECT app, name applied FROM django_migrations
WHERE app = 'casbin_adapter' AND name = '0001_initial';
"""
)
row = cursor.fetchone()
else:
with connection.cursor() as cursor:
cursor.execute(
"""
SELECT app, name applied FROM django_migrations
WHERE app = 'casbin_adapter' AND name = '0001_initial';
"""
)
row = cursor.fetchone()

if row:
enforcer._load()
except (OperationalError, ProgrammingError):
pass
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
casbin==1.16.10
casbin>=1.16.10
Django

3 changes: 1 addition & 2 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
-r requirements.txt
setuptools==60.2.0
simpleeval==0.9.12
setuptools==60.2.0
1 change: 0 additions & 1 deletion tests/test_adapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import casbin
import simpleeval
from unittest import TestCase

from django.test import TestCase
from casbin_adapter.models import CasbinRule
Expand Down