Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add support for sqlalchemy 2.0 #3066

Merged
merged 3 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# Changelog

## 3.18.0

### Enhancements:

- Add support for `sqlalchemy 2.0` (mihran113)

## 3.17.6

- Switch to patched version of official `pynvml` (mihran113)
- Remove telemetry tracking (mihran113)

Expand Down
3 changes: 1 addition & 2 deletions aim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from aim.ext.notebook.notebook import load_ipython_extension
from aim.cli.manager.manager import run_process

from aim.utils.deprecation import python_version_deprecation_check, sqlalchemy_version_check
from aim.utils.deprecation import python_version_deprecation_check

python_version_deprecation_check()
sqlalchemy_version_check()
8 changes: 6 additions & 2 deletions aim/storage/structured/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def keys(self) -> list:
self._cached = True
return list(self._data.keys())

def empty_cache(self):
self._data.clear()
self._cached = False

def __setitem__(self, key, value):
assert self._cached
self._data[key] = value
Expand Down Expand Up @@ -56,7 +60,7 @@ def __init__(self, path: str, readonly: bool = False):
self.readonly = readonly
self.engine = create_engine(self.db_url,
echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))))
self.session_cls = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=self.engine))
self.session_cls = scoped_session(sessionmaker(autoflush=False, bind=self.engine))
self._upgraded = None

@classmethod
Expand Down Expand Up @@ -85,7 +89,7 @@ def caches(self):

def get_session(self, autocommit=True):
session = self.session_cls()
session.autocommit = autocommit
setattr(session, 'autocommit', autocommit)
return session

def run_upgrades(self):
Expand Down
73 changes: 43 additions & 30 deletions aim/storage/structured/sql_engine/entities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytz

from typing import Collection, Union, List, Optional
from sqlalchemy import delete
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import IntegrityError

Expand All @@ -27,6 +28,13 @@
from aim.storage.structured.sql_engine.utils import ModelMappedProperty as Property


def session_commit_or_flush(session):
if getattr(session, 'autocommit', True):
session.commit()
else:
session.flush()


def timestamp_or_none(dt):
if dt is None:
return None
Expand Down Expand Up @@ -71,14 +79,14 @@ def from_hash(cls, runhash: str, created_at, session) -> 'ModelMappedRun':
raise ValueError(f'Run with hash \'{runhash}\' already exists.')
run = RunModel(runhash, created_at)
session.add(run)
session.flush()
session_commit_or_flush(session)
return ModelMappedRun(run, session)

@classmethod
def delete_run(cls, runhash: str, session) -> bool:
try:
rows_affected = session.query(RunModel).filter(RunModel.hash == runhash).delete()
session.flush()
session_commit_or_flush(session)
except Exception:
return False
return rows_affected > 0
Expand Down Expand Up @@ -141,18 +149,19 @@ def experiment(self) -> Union[str, SafeNone]:

@property
def info(self) -> Optional[IRunInfo]:
session = self._session
if self._model:
if self._model.info:
return ModelMappedRunInfo(self._model.info, self._session)
return ModelMappedRunInfo(self._model.info, session)
else:
info = RunInfoModel()

self._model.info = info
self._session.add(info)
self._session.add(self._model)
self._session.flush()
session.add(info)
session.add(self._model)
session_commit_or_flush(session)

return ModelMappedRunInfo(self._model.info, self._session)
return ModelMappedRunInfo(self._model.info, session)

@experiment.setter
def experiment(self, value: str):
Expand All @@ -170,7 +179,7 @@ def unsafe_set_exp():
session = self._session
unsafe_set_exp()
try:
session.flush()
session_commit_or_flush(session)
except IntegrityError:
session.rollback()
unsafe_set_exp()
Expand All @@ -196,7 +205,7 @@ def add_tag(self, value: str) -> None:
session.add(tag)
self._model.tags.append(tag)
session.add(self._model)
session.flush()
session_commit_or_flush(session)

def remove_tag(self, tag_name: str) -> bool:
session = self._session
Expand All @@ -207,7 +216,7 @@ def remove_tag(self, tag_name: str) -> bool:
tag_removed = True
break
session.add(self._model)
session.flush()
session_commit_or_flush(session)
return tag_removed

@property
Expand Down Expand Up @@ -248,13 +257,14 @@ def add_note(self, content: str):
note = NoteModel(content)
session.add(note)
self._model.notes.append(note)
session.flush()

audit_log = NoteAuditLogModel(action="Created", before=None, after=content)
audit_log.note_id = note.id
session.add(audit_log)
note.audit_logs.append(audit_log)

session.add(self._model)
session.flush()
session_commit_or_flush(session)

return note

Expand All @@ -267,11 +277,12 @@ def update_note(self, _id: int, content):
note.content = content

audit_log = NoteAuditLogModel(action="Updated", before=before, after=content)
audit_log.note_id = _id
session.add(audit_log)
note.audit_logs.append(audit_log)

session.add(note)
session.flush()

session_commit_or_flush(session)

return note

Expand All @@ -282,11 +293,11 @@ def remove_note(self, _id: int):
audit_log.note_id = _id
session.add(audit_log)

session.query(NoteModel).filter(
NoteModel.run_id == self._model.id,
NoteModel.id == _id,
).delete()
session.flush()
delete_stmnt = delete(NoteModel).where(NoteModel.run_id == self._model.id,
NoteModel.id == _id,)
session.execute(delete_stmnt)

session_commit_or_flush(session)


class ModelMappedExperiment(IExperiment, metaclass=ModelMappedClassMeta):
Expand Down Expand Up @@ -327,7 +338,7 @@ def from_name(cls, name: str, session) -> 'ModelMappedExperiment':
raise ValueError(f'Experiment with name \'{name}\' already exists.')
exp = ExperimentModel(name)
session.add(exp)
session.flush()
session_commit_or_flush(session)
return ModelMappedExperiment(exp, session)

@property
Expand Down Expand Up @@ -400,13 +411,14 @@ def add_note(self, content: str):
note = NoteModel(content)
session.add(note)
self._model.notes.append(note)
session.flush()

audit_log = NoteAuditLogModel(action="Created", before=None, after=content)
audit_log.note_id = note.id
session.add(audit_log)
note.audit_logs.append(audit_log)

session.add(self._model)
session.flush()
session_commit_or_flush(session)

return note

Expand All @@ -419,11 +431,11 @@ def update_note(self, _id: int, content):
note.content = content

audit_log = NoteAuditLogModel(action="Updated", before=before, after=content)
audit_log.note_id = note.id
session.add(audit_log)
note.audit_logs.append(audit_log)

session.add(note)
session.flush()
session_commit_or_flush(session)

return note

Expand All @@ -434,11 +446,10 @@ def remove_note(self, _id: int):
audit_log.note_id = _id
session.add(audit_log)

session.query(NoteModel).filter(
NoteModel.experiment_id == self._model.id,
NoteModel.id == _id,
).delete()
session.flush()
delete_stmnt = delete(NoteModel).where(NoteModel.experiment_id == self._model.id,
NoteModel.id == _id, )
session.execute(delete_stmnt)
session_commit_or_flush(session)

def refresh_model(self):
self._session.refresh(self._model)
Expand Down Expand Up @@ -481,7 +492,7 @@ def from_name(cls, name: str, session) -> 'ModelMappedTag':
raise ValueError(f'Tag with name \'{name}\' already exists.')
tag = TagModel(name)
session.add(tag)
session.flush()
session_commit_or_flush(session)
return ModelMappedTag(tag, session)

@classmethod
Expand Down Expand Up @@ -525,6 +536,7 @@ def delete(cls, _id: str, **kwargs) -> bool:
model_obj = session.query(TagModel).filter(TagModel.uuid == _id).first()
if model_obj:
session.delete(model_obj)
session_commit_or_flush(session)
return True
return False

Expand Down Expand Up @@ -592,6 +604,7 @@ def delete(cls, _id: str, **kwargs) -> bool:
model_obj = session.query(NoteModel).filter(NoteModel.id == _id).first()
if model_obj:
session.delete(model_obj)
session_commit_or_flush(session)
return True
return False

Expand Down
4 changes: 3 additions & 1 deletion aim/storage/structured/sql_engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ def setter(object_, value):
try:
setattr(object_._model, self.mapped_name, value)
object_._session.add(object_._model)
if object_._session.autocommit:
if getattr(object_._session, 'autocommit', True):
object_._session.commit()
else:
object_._session.flush()
except Exception:
direct_setter(object_, value)

Expand Down
10 changes: 0 additions & 10 deletions aim/utils/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@
DEFAULT_MSG_TEMPLATE = 'This functionality will be removed in'


def sqlalchemy_version_check():
import sqlalchemy
import packaging.version
from aim.__version__ import __version__
if packaging.version.parse(sqlalchemy.__version__) >= packaging.version.parse('2.0.0'):
raise RuntimeError(f'Aim v{__version__} does not support sqlalchemy v{sqlalchemy.__version__}. '
f'Please check the following issue for further updates: '
f'https://github.com/aimhubio/aim/issues/2514')


def python_version_deprecation_check():
import sys
version_info = sys.version_info
Expand Down
2 changes: 1 addition & 1 deletion aim/web/api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))),
connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
SessionLocal = sessionmaker(autoflush=False, bind=engine)
Base = declarative_base()


Expand Down
4 changes: 4 additions & 0 deletions aim/web/api/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def cleanup_repo_pools(self):
self.repo.container_view_pool.clear()
self.repo.persistent_pool.clear()

def cleanup_sql_caches(self):
for cache in self.repo.structured_db.caches.values():
cache.empty_cache()

def exists(self):
"""
Checks whether .aim repository is created
Expand Down
1 change: 1 addition & 0 deletions aim/web/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# cleanup repo pools after each api call
project = Project()
project.cleanup_repo_pools()
project.cleanup_sql_caches()
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def package_files(directory):
'fastapi<1,>=0.69.0',
'jinja2<4,>=2.10.0',
'pytz>=2019.1',
'SQLAlchemy<2,>=1.4.1',
'SQLAlchemy>=1.4.1',
'uvicorn<1,>=0.12.0',
'Pillow>=8.0.0',
'protobuf<5,>=3.9.2',
Expand Down
Loading