From 96d44fbccd721bc28d492bdf374f9e69bcb416a4 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Fri, 24 Nov 2023 04:21:50 +0400 Subject: [PATCH 1/3] [feat] Add support for `sqlalchemy 2.0` --- CHANGELOG.md | 7 +++ aim/__init__.py | 3 +- aim/storage/structured/db.py | 6 +- aim/storage/structured/sql_engine/entities.py | 56 ++++++++++--------- aim/utils/deprecation.py | 10 ---- aim/web/api/projects/project.py | 4 ++ aim/web/api/utils.py | 1 + setup.py | 2 +- 8 files changed, 49 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f25507909b..558d488834 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,13 @@ # Changelog +## 3.18.0 + +### Enhancements: + +- Add support for `sqlalchemy 2.0` (mihran113) + ## 3.17.6 + - Switch to patched version of official `pynvml` (mihran113) - Remove telemetry tracking (mihran113) diff --git a/aim/__init__.py b/aim/__init__.py index 7f10698f92..5f330e258c 100644 --- a/aim/__init__.py +++ b/aim/__init__.py @@ -3,7 +3,6 @@ from aim.ext.notebook.notebook import load_ipython_extension from aim.cli.manager.manager import run_process -from aim.utils.deprecation import python_version_deprecation_check, sqlalchemy_version_check +from aim.utils.deprecation import python_version_deprecation_check python_version_deprecation_check() -sqlalchemy_version_check() diff --git a/aim/storage/structured/db.py b/aim/storage/structured/db.py index bc48df44a8..5261eb01c2 100644 --- a/aim/storage/structured/db.py +++ b/aim/storage/structured/db.py @@ -29,6 +29,10 @@ def keys(self) -> list: self._cached = True return list(self._data.keys()) + def empty_cache(self): + self._data.clear() + self._cached = False + def __setitem__(self, key, value): assert self._cached self._data[key] = value @@ -56,7 +60,7 @@ def __init__(self, path: str, readonly: bool = False): self.readonly = readonly self.engine = create_engine(self.db_url, echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING)))) - self.session_cls = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=self.engine)) + self.session_cls = scoped_session(sessionmaker(self.engine)) self._upgraded = None @classmethod diff --git a/aim/storage/structured/sql_engine/entities.py b/aim/storage/structured/sql_engine/entities.py index 168f057eea..6017afeaf0 100644 --- a/aim/storage/structured/sql_engine/entities.py +++ b/aim/storage/structured/sql_engine/entities.py @@ -1,6 +1,7 @@ import pytz from typing import Collection, Union, List, Optional +from sqlalchemy import delete from sqlalchemy.orm import joinedload from sqlalchemy.exc import IntegrityError @@ -71,14 +72,14 @@ def from_hash(cls, runhash: str, created_at, session) -> 'ModelMappedRun': raise ValueError(f'Run with hash \'{runhash}\' already exists.') run = RunModel(runhash, created_at) session.add(run) - session.flush() + session.commit() return ModelMappedRun(run, session) @classmethod def delete_run(cls, runhash: str, session) -> bool: try: rows_affected = session.query(RunModel).filter(RunModel.hash == runhash).delete() - session.flush() + session.commit() except Exception: return False return rows_affected > 0 @@ -150,7 +151,7 @@ def info(self) -> Optional[IRunInfo]: self._model.info = info self._session.add(info) self._session.add(self._model) - self._session.flush() + self._session.commit() return ModelMappedRunInfo(self._model.info, self._session) @@ -170,7 +171,7 @@ def unsafe_set_exp(): session = self._session unsafe_set_exp() try: - session.flush() + session.commit() except IntegrityError: session.rollback() unsafe_set_exp() @@ -196,7 +197,7 @@ def add_tag(self, value: str) -> None: session.add(tag) self._model.tags.append(tag) session.add(self._model) - session.flush() + session.commit() def remove_tag(self, tag_name: str) -> bool: session = self._session @@ -207,7 +208,7 @@ def remove_tag(self, tag_name: str) -> bool: tag_removed = True break session.add(self._model) - session.flush() + session.commit() return tag_removed @property @@ -248,13 +249,14 @@ def add_note(self, content: str): note = NoteModel(content) session.add(note) self._model.notes.append(note) + session.flush() audit_log = NoteAuditLogModel(action="Created", before=None, after=content) + audit_log.note_id = note.id session.add(audit_log) - note.audit_logs.append(audit_log) session.add(self._model) - session.flush() + session.commit() return note @@ -267,11 +269,11 @@ def update_note(self, _id: int, content): note.content = content audit_log = NoteAuditLogModel(action="Updated", before=before, after=content) + audit_log.note_id = _id session.add(audit_log) - note.audit_logs.append(audit_log) session.add(note) - session.flush() + session.commit() return note @@ -282,11 +284,11 @@ def remove_note(self, _id: int): audit_log.note_id = _id session.add(audit_log) - session.query(NoteModel).filter( - NoteModel.run_id == self._model.id, - NoteModel.id == _id, - ).delete() - session.flush() + delete_stmnt = delete(NoteModel).where(NoteModel.run_id == self._model.id, + NoteModel.id == _id,) + session.execute(delete_stmnt) + + session.commit() class ModelMappedExperiment(IExperiment, metaclass=ModelMappedClassMeta): @@ -327,7 +329,7 @@ def from_name(cls, name: str, session) -> 'ModelMappedExperiment': raise ValueError(f'Experiment with name \'{name}\' already exists.') exp = ExperimentModel(name) session.add(exp) - session.flush() + session.commit() return ModelMappedExperiment(exp, session) @property @@ -400,13 +402,14 @@ def add_note(self, content: str): note = NoteModel(content) session.add(note) self._model.notes.append(note) + session.flush() audit_log = NoteAuditLogModel(action="Created", before=None, after=content) + audit_log.note_id = note.id session.add(audit_log) - note.audit_logs.append(audit_log) session.add(self._model) - session.flush() + session.commit() return note @@ -419,11 +422,11 @@ def update_note(self, _id: int, content): note.content = content audit_log = NoteAuditLogModel(action="Updated", before=before, after=content) + audit_log.note_id = note.id session.add(audit_log) - note.audit_logs.append(audit_log) session.add(note) - session.flush() + session.commit() return note @@ -434,11 +437,10 @@ def remove_note(self, _id: int): audit_log.note_id = _id session.add(audit_log) - session.query(NoteModel).filter( - NoteModel.experiment_id == self._model.id, - NoteModel.id == _id, - ).delete() - session.flush() + delete_stmnt = delete(NoteModel).where(NoteModel.experiment_id == self._model.id, + NoteModel.id == _id, ) + session.execute(delete_stmnt) + session.commit() def refresh_model(self): self._session.refresh(self._model) @@ -481,7 +483,7 @@ def from_name(cls, name: str, session) -> 'ModelMappedTag': raise ValueError(f'Tag with name \'{name}\' already exists.') tag = TagModel(name) session.add(tag) - session.flush() + session.commit() return ModelMappedTag(tag, session) @classmethod @@ -525,6 +527,7 @@ def delete(cls, _id: str, **kwargs) -> bool: model_obj = session.query(TagModel).filter(TagModel.uuid == _id).first() if model_obj: session.delete(model_obj) + session.commit() return True return False @@ -592,6 +595,7 @@ def delete(cls, _id: str, **kwargs) -> bool: model_obj = session.query(NoteModel).filter(NoteModel.id == _id).first() if model_obj: session.delete(model_obj) + session.commit() return True return False diff --git a/aim/utils/deprecation.py b/aim/utils/deprecation.py index d8096c9406..f654709623 100644 --- a/aim/utils/deprecation.py +++ b/aim/utils/deprecation.py @@ -5,16 +5,6 @@ DEFAULT_MSG_TEMPLATE = 'This functionality will be removed in' -def sqlalchemy_version_check(): - import sqlalchemy - import packaging.version - from aim.__version__ import __version__ - if packaging.version.parse(sqlalchemy.__version__) >= packaging.version.parse('2.0.0'): - raise RuntimeError(f'Aim v{__version__} does not support sqlalchemy v{sqlalchemy.__version__}. ' - f'Please check the following issue for further updates: ' - f'https://github.com/aimhubio/aim/issues/2514') - - def python_version_deprecation_check(): import sys version_info = sys.version_info diff --git a/aim/web/api/projects/project.py b/aim/web/api/projects/project.py index 9b918e2cf3..ebd0606b5a 100644 --- a/aim/web/api/projects/project.py +++ b/aim/web/api/projects/project.py @@ -21,6 +21,10 @@ def cleanup_repo_pools(self): self.repo.container_view_pool.clear() self.repo.persistent_pool.clear() + def cleanup_sql_caches(self): + for cache in self.repo.structured_db.caches.values(): + cache.empty_cache() + def exists(self): """ Checks whether .aim repository is created diff --git a/aim/web/api/utils.py b/aim/web/api/utils.py index bbdade5916..844523557a 100644 --- a/aim/web/api/utils.py +++ b/aim/web/api/utils.py @@ -58,3 +58,4 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # cleanup repo pools after each api call project = Project() project.cleanup_repo_pools() + project.cleanup_sql_caches() diff --git a/setup.py b/setup.py index 7a6411b156..444a44fbee 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,7 @@ def package_files(directory): 'fastapi<1,>=0.69.0', 'jinja2<4,>=2.10.0', 'pytz>=2019.1', - 'SQLAlchemy<2,>=1.4.1', + 'SQLAlchemy>=1.4.1', 'uvicorn<1,>=0.12.0', 'Pillow>=8.0.0', 'protobuf<5,>=3.9.2', From be32d45a2b05416c9a6fd17dd3e9f49fd90d0b60 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Fri, 24 Nov 2023 22:04:30 +0400 Subject: [PATCH 2/3] Mimic old autocommit options --- aim/storage/structured/db.py | 4 +- aim/storage/structured/sql_engine/entities.py | 49 +++++++++++-------- aim/web/api/db.py | 2 +- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/aim/storage/structured/db.py b/aim/storage/structured/db.py index 5261eb01c2..3ffb070eae 100644 --- a/aim/storage/structured/db.py +++ b/aim/storage/structured/db.py @@ -60,7 +60,7 @@ def __init__(self, path: str, readonly: bool = False): self.readonly = readonly self.engine = create_engine(self.db_url, echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING)))) - self.session_cls = scoped_session(sessionmaker(self.engine)) + self.session_cls = scoped_session(sessionmaker(autoflush=False, bind=self.engine)) self._upgraded = None @classmethod @@ -89,7 +89,7 @@ def caches(self): def get_session(self, autocommit=True): session = self.session_cls() - session.autocommit = autocommit + setattr(session, 'autocommit', autocommit) return session def run_upgrades(self): diff --git a/aim/storage/structured/sql_engine/entities.py b/aim/storage/structured/sql_engine/entities.py index 6017afeaf0..87ed93e049 100644 --- a/aim/storage/structured/sql_engine/entities.py +++ b/aim/storage/structured/sql_engine/entities.py @@ -28,6 +28,13 @@ from aim.storage.structured.sql_engine.utils import ModelMappedProperty as Property +def session_commit_or_flush(session): + if getattr(session, 'autocommit', True): + session.commit() + else: + session.flush() + + def timestamp_or_none(dt): if dt is None: return None @@ -72,14 +79,14 @@ def from_hash(cls, runhash: str, created_at, session) -> 'ModelMappedRun': raise ValueError(f'Run with hash \'{runhash}\' already exists.') run = RunModel(runhash, created_at) session.add(run) - session.commit() + session_commit_or_flush(session) return ModelMappedRun(run, session) @classmethod def delete_run(cls, runhash: str, session) -> bool: try: rows_affected = session.query(RunModel).filter(RunModel.hash == runhash).delete() - session.commit() + session_commit_or_flush(session) except Exception: return False return rows_affected > 0 @@ -142,18 +149,19 @@ def experiment(self) -> Union[str, SafeNone]: @property def info(self) -> Optional[IRunInfo]: + session = self._session if self._model: if self._model.info: - return ModelMappedRunInfo(self._model.info, self._session) + return ModelMappedRunInfo(self._model.info, session) else: info = RunInfoModel() self._model.info = info - self._session.add(info) - self._session.add(self._model) - self._session.commit() + session.add(info) + session.add(self._model) + session_commit_or_flush(session) - return ModelMappedRunInfo(self._model.info, self._session) + return ModelMappedRunInfo(self._model.info, session) @experiment.setter def experiment(self, value: str): @@ -171,7 +179,7 @@ def unsafe_set_exp(): session = self._session unsafe_set_exp() try: - session.commit() + session_commit_or_flush(session) except IntegrityError: session.rollback() unsafe_set_exp() @@ -197,7 +205,7 @@ def add_tag(self, value: str) -> None: session.add(tag) self._model.tags.append(tag) session.add(self._model) - session.commit() + session_commit_or_flush(session) def remove_tag(self, tag_name: str) -> bool: session = self._session @@ -208,7 +216,7 @@ def remove_tag(self, tag_name: str) -> bool: tag_removed = True break session.add(self._model) - session.commit() + session_commit_or_flush(session) return tag_removed @property @@ -256,7 +264,7 @@ def add_note(self, content: str): session.add(audit_log) session.add(self._model) - session.commit() + session_commit_or_flush(session) return note @@ -273,7 +281,8 @@ def update_note(self, _id: int, content): session.add(audit_log) session.add(note) - session.commit() + + session_commit_or_flush(session) return note @@ -288,7 +297,7 @@ def remove_note(self, _id: int): NoteModel.id == _id,) session.execute(delete_stmnt) - session.commit() + session_commit_or_flush(session) class ModelMappedExperiment(IExperiment, metaclass=ModelMappedClassMeta): @@ -329,7 +338,7 @@ def from_name(cls, name: str, session) -> 'ModelMappedExperiment': raise ValueError(f'Experiment with name \'{name}\' already exists.') exp = ExperimentModel(name) session.add(exp) - session.commit() + session_commit_or_flush(session) return ModelMappedExperiment(exp, session) @property @@ -409,7 +418,7 @@ def add_note(self, content: str): session.add(audit_log) session.add(self._model) - session.commit() + session_commit_or_flush(session) return note @@ -426,7 +435,7 @@ def update_note(self, _id: int, content): session.add(audit_log) session.add(note) - session.commit() + session_commit_or_flush(session) return note @@ -440,7 +449,7 @@ def remove_note(self, _id: int): delete_stmnt = delete(NoteModel).where(NoteModel.experiment_id == self._model.id, NoteModel.id == _id, ) session.execute(delete_stmnt) - session.commit() + session_commit_or_flush(session) def refresh_model(self): self._session.refresh(self._model) @@ -483,7 +492,7 @@ def from_name(cls, name: str, session) -> 'ModelMappedTag': raise ValueError(f'Tag with name \'{name}\' already exists.') tag = TagModel(name) session.add(tag) - session.commit() + session_commit_or_flush(session) return ModelMappedTag(tag, session) @classmethod @@ -527,7 +536,7 @@ def delete(cls, _id: str, **kwargs) -> bool: model_obj = session.query(TagModel).filter(TagModel.uuid == _id).first() if model_obj: session.delete(model_obj) - session.commit() + session_commit_or_flush(session) return True return False @@ -595,7 +604,7 @@ def delete(cls, _id: str, **kwargs) -> bool: model_obj = session.query(NoteModel).filter(NoteModel.id == _id).first() if model_obj: session.delete(model_obj) - session.commit() + session_commit_or_flush(session) return True return False diff --git a/aim/web/api/db.py b/aim/web/api/db.py index 6ba9d37aaf..452505b791 100644 --- a/aim/web/api/db.py +++ b/aim/web/api/db.py @@ -14,7 +14,7 @@ echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))), connect_args={"check_same_thread": False} ) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +SessionLocal = sessionmaker(autoflush=False, bind=engine) Base = declarative_base() From d32ec6e03b54220a4d1c5e824da9eea651541fef Mon Sep 17 00:00:00 2001 From: mihran113 Date: Fri, 24 Nov 2023 22:08:04 +0400 Subject: [PATCH 3/3] Missed changes --- aim/storage/structured/sql_engine/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aim/storage/structured/sql_engine/utils.py b/aim/storage/structured/sql_engine/utils.py index adce4fb25e..92ed706db0 100644 --- a/aim/storage/structured/sql_engine/utils.py +++ b/aim/storage/structured/sql_engine/utils.py @@ -45,8 +45,10 @@ def setter(object_, value): try: setattr(object_._model, self.mapped_name, value) object_._session.add(object_._model) - if object_._session.autocommit: + if getattr(object_._session, 'autocommit', True): object_._session.commit() + else: + object_._session.flush() except Exception: direct_setter(object_, value)