Skip to content

Commit

Permalink
Refactor Sessions to clarify their usage
Browse files Browse the repository at this point in the history
DBSession has been renamed to HandlerSession.

The following exist now:
- HandlerSession: a DB session scoped to the current handler;
  should therefore only be used inside handlers.
- ThreadSession: a DB session scoped to the current thread;
  should be used outside of handlers, e.g. when firing off
  tasks in threads, or inside of services.

A special HandlerSession, with knowledge of the current user, and that
verifies permissions on commit, is available as `self.Session` on the
base handler. This is the *preferred way of accessing* HandlerSession.

To make it easier to get hold of the current engine, the Sessions each
have a `.engine` attribute: `HandlerSession.engine`.

This avoids having to access the engine as
`HandlerSession.session_factory.kw["bind"]`.
  • Loading branch information
stefanv committed Dec 8, 2023
1 parent 57b7faa commit 3fdfdda
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 85 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
pass_filenames: true
exclude: baselayer|node_modules|static
- repo: https://github.com/pycqa/flake8
rev: 3.8.4
rev: 6.1.0
hooks:
- id: flake8
pass_filenames: true
Expand Down
4 changes: 2 additions & 2 deletions app/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy.orm import joinedload

from baselayer.app.custom_exceptions import AccessError # noqa: F401
from baselayer.app.models import DBSession, Role, Token, User # noqa: F401
from baselayer.app.models import HandlerSession, Role, Token, User # noqa: F401


def auth_or_token(method):
Expand All @@ -26,7 +26,7 @@ def wrapper(self, *args, **kwargs):
token_header = self.request.headers.get("Authorization", None)
if token_header is not None and token_header.startswith("token "):
token_id = token_header.replace("token", "").strip()
with DBSession() as session:
with HandlerSession() as session:
token = session.scalars(
sa.select(Token)
.options(
Expand Down
30 changes: 18 additions & 12 deletions app/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from ..env import load_env
from ..flow import Flow
from ..json_util import to_json
from ..models import DBSession, User, VerifiedSession, bulk_verify, session_context_id
from ..models import (
HandlerSession,
User,
VerifiedSession,
bulk_verify,
session_context_id,
)

env, cfg = load_env()
log = make_log("basehandler")
Expand All @@ -49,7 +55,7 @@ def get_current_user(self):
user_id = int(self.user_id())
oauth_uid = self.get_secure_cookie("user_oauth_uid")
if user_id and oauth_uid:
with DBSession() as session:
with HandlerSession() as session:
try:
user = session.scalars(
sqlalchemy.select(User).where(User.id == user_id)
Expand All @@ -74,7 +80,7 @@ def get_current_user(self):
return None

def login_user(self, user):
with DBSession() as session:
with HandlerSession() as session:
try:
self.set_secure_cookie("user_id", str(user.id))
user = session.scalars(
Expand Down Expand Up @@ -120,7 +126,7 @@ def log_exception(self, typ=None, value=None, tb=None):
)

def on_finish(self):
DBSession.remove()
HandlerSession.remove()


class BaseHandler(PSABaseHandler):
Expand Down Expand Up @@ -153,7 +159,7 @@ def Session(self):
# must merge the user object with the current session
# ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#adding-new-or-existing-items
session.add(self.current_user)
session.bind = DBSession.session_factory.kw["bind"]
session.bind = HandlerSession.engine
yield session

def verify_permissions(self):
Expand All @@ -164,20 +170,20 @@ def verify_permissions(self):
"""

# get items to be inserted
new_rows = [row for row in DBSession().new]
new_rows = [row for row in HandlerSession().new]

# get items to be updated
updated_rows = [
row for row in DBSession().dirty if DBSession().is_modified(row)
row for row in HandlerSession().dirty if HandlerSession().is_modified(row)
]

# get items to be deleted
deleted_rows = [row for row in DBSession().deleted]
deleted_rows = [row for row in HandlerSession().deleted]

# get items that were read
read_rows = [
row
for row in set(DBSession().identity_map.values())
for row in set(HandlerSession().identity_map.values())
- (set(updated_rows) | set(new_rows) | set(deleted_rows))
]

Expand All @@ -194,15 +200,15 @@ def verify_permissions(self):
# update transaction state in DB, but don't commit yet. this updates
# or adds rows in the database and uses their new state in joins,
# for permissions checking purposes.
DBSession().flush()
HandlerSession().flush()
bulk_verify("create", new_rows, self.current_user)

def verify_and_commit(self):
"""Verify permissions on the current database session and commit if
successful, otherwise raise an AccessError.
"""
self.verify_permissions()
DBSession().commit()
HandlerSession().commit()

def prepare(self):
self.cfg = self.application.cfg
Expand All @@ -225,7 +231,7 @@ def prepare(self):
N = 5
for i in range(1, N + 1):
try:
assert DBSession.session_factory.kw["bind"] is not None
assert HandlerSession.engine is not None
except Exception as e:
if i == N:
raise e
Expand Down
17 changes: 8 additions & 9 deletions app/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ def status(message):
else:
print(f"\r[✓] {message}")
finally:
models.DBSession().commit()
models.HandlerSession().commit()


def drop_tables():
conn = models.DBSession.session_factory.kw["bind"]
print(f"Dropping tables on database {conn.url.database}")
engine = models.HandlerSession.engine
print(f"Dropping tables on database {engine.url.database}")
meta = sa.MetaData()
meta.reflect(bind=conn)
meta.drop_all(bind=conn)
meta.reflect(bind=engine)
meta.drop_all(bind=engine)


def create_tables(retry=5, add=True):
Expand All @@ -45,17 +45,16 @@ def create_tables(retry=5, add=True):
tables.
"""
conn = models.DBSession.session_factory.kw["bind"]
tables = models.Base.metadata.sorted_tables
if tables and not add:
print("Existing tables found; not creating additional tables")
return

for i in range(1, retry + 1):
try:
conn = models.DBSession.session_factory.kw["bind"]
print(f"Creating tables on database {conn.url.database}")
models.Base.metadata.create_all(conn)
engine = models.HandlerSession.engine
print(f"Creating tables on database {engine.url.database}")
models.Base.metadata.create_all(engine)

table_list = ", ".join(list(models.Base.metadata.tables.keys()))
print(f"Refreshed tables: {table_list}")
Expand Down
Loading

0 comments on commit 3fdfdda

Please sign in to comment.