Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

serverless: add create_objects and fix several issues #178

Merged
merged 34 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
1895719
add missing dependencies for serverless
viseshrp Oct 9, 2024
aaa3e9a
reset flag before save
viseshrp Oct 9, 2024
51f72c5
fix some type hints
viseshrp Oct 9, 2024
d3abc60
fix formatting
viseshrp Oct 9, 2024
597aac5
Delete top.html
viseshrp Oct 9, 2024
77566b1
Delete test_serverless.py
viseshrp Oct 10, 2024
598a75f
formatting
viseshrp Oct 10, 2024
98d3d0f
disable docs-style
viseshrp Oct 10, 2024
cfe846c
handle database errors with save()
viseshrp Oct 11, 2024
ba3c552
make sure to close file handles properly
viseshrp Oct 18, 2024
79c1d7f
pass generic kwargs to find()
viseshrp Oct 18, 2024
706248d
Merge branch 'main' into feat/serverless
viseshrp Oct 21, 2024
73be3d6
use the file handle properly for file payload saves
viseshrp Oct 21, 2024
9714097
add create_objects and copy_objects stubs
viseshrp Oct 21, 2024
886f03f
Merge branch 'main' into feat/serverless
viseshrp Oct 21, 2024
8a2f941
migrate multiple databases at once
viseshrp Oct 23, 2024
86ea319
add as_dict and get_or_create
viseshrp Oct 23, 2024
efbf622
fix as_dict
viseshrp Oct 23, 2024
b3988cb
fix Session attrs
viseshrp Oct 23, 2024
c061a07
allow setting guid
viseshrp Oct 23, 2024
dedf8ee
don't reset the orm object before save
viseshrp Oct 25, 2024
0b8869b
add a few helper methods
viseshrp Oct 26, 2024
82c80f4
fix some exceptions in copy_objects
viseshrp Oct 27, 2024
ff96237
Merge branch 'main' into feat/serverless
viseshrp Oct 28, 2024
447ca14
fix style
viseshrp Oct 28, 2024
2367c74
add a reinit method for integrity
viseshrp Nov 1, 2024
f1d275e
fix save in file payloads
viseshrp Nov 1, 2024
484c0c1
fix ObjectSet.delete
viseshrp Nov 1, 2024
a8b4dc6
fix copying of items
viseshrp Nov 1, 2024
b5a28dd
Update .gitignore
viseshrp Nov 6, 2024
cff2ce7
Merge branch 'main' into feat/serverless
viseshrp Nov 12, 2024
e49ae63
fix imports
viseshrp Nov 12, 2024
03d0f50
revert copy_objects
viseshrp Nov 12, 2024
860f336
fix _get_db_dir
viseshrp Nov 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/ansys/dynamicreporting/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,9 @@ class MultipleObjectsReturnedError(ADRException):
"""Exception raised if only one object was expected, but multiple were returned."""

detail = "get() returned more than one object."


class IntegrityError(ADRException):
"""Exception raised if there is a constraint violation while saving an object in the database."""

detail = "A database integrity check failed."
69 changes: 51 additions & 18 deletions src/ansys/dynamicreporting/core/serverless/adr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections.abc import Iterable
import os
from pathlib import Path
import platform
import shutil
import sys
from typing import Any, Optional, Type, Union
import uuid
Expand Down Expand Up @@ -136,12 +138,29 @@ def _get_install_directory(self, ansys_installation: str) -> Path:
raise InvalidAnsysPath(f"Unable to detect an installation in: {','.join(dirs_to_check)}")

def _check_dir(self, dir_):
dir_path = Path(dir_)
dir_path = Path(dir_) if not isinstance(dir_, Path) else dir_
if not dir_path.exists() or not dir_path.is_dir():
self._logger.error(f"Invalid directory path: {dir_}")
raise InvalidPath(extra_detail=dir_)
return dir_path

def _migrate_db(self, db):
try: # upgrade databases
management.call_command("migrate", "--no-input", "--database", db, verbosity=0)
except Exception as e:
self._logger.error(f"{e}")
raise DatabaseMigrationError(extra_detail=str(e))
else:
from django.contrib.auth.models import Group, Permission, User

if not User.objects.using(db).filter(is_superuser=True).exists():
user = User.objects.using(db).create_superuser("nexus", "", "cei")
# include the nexus group (with all permissions)
nexus_group, created = Group.objects.using(db).get_or_create(name="nexus")
if created:
nexus_group.permissions.set(Permission.objects.using(db).all())
nexus_group.user_set.add(user)

def setup(self, collect_static: bool = False) -> None:
from django.conf import settings

Expand Down Expand Up @@ -207,22 +226,11 @@ def setup(self, collect_static: bool = False) -> None:
raise ImproperlyConfiguredError(extra_detail=str(e))

# migrations
if self._db_directory is not None:
try: # upgrades all databases
management.call_command("migrate", "--no-input", verbosity=0)
except Exception as e:
self._logger.error(f"{e}")
raise DatabaseMigrationError(extra_detail=str(e))
else:
from django.contrib.auth.models import Group, Permission, User

if not User.objects.filter(is_superuser=True).exists():
user = User.objects.create_superuser("nexus", "", "cei")
# include the nexus group (with all permissions)
nexus_group, created = Group.objects.get_or_create(name="nexus")
if created:
nexus_group.permissions.set(Permission.objects.all())
nexus_group.user_set.add(user)
if self._databases:
for db in self._databases:
self._migrate_db(db)
elif self._db_directory is not None:
self._migrate_db("default")

# geometry migration
try:
Expand Down Expand Up @@ -339,8 +347,33 @@ def query(
self,
query_type: Union[Session, Dataset, Type[Item], Type[Template]],
query: str = "",
**kwargs: Any,
) -> ObjectSet:
if not issubclass(query_type, (Item, Template, Session, Dataset)):
self._logger.error(f"{query_type} is not valid")
raise TypeError(f"{query_type} is not valid")
return query_type.find(query=query)
return query_type.find(query=query, **kwargs)

@staticmethod
def create_objects(
objects: Union[list, ObjectSet],
**kwargs: Any,
) -> int:
if not isinstance(objects, Iterable):
raise ADRException("objects must be an iterable")
count = 0
for obj in objects:
if kwargs.get("using", "default") != obj.db:
# required if copying across databases
obj.reinit()
obj.save(**kwargs)
count += 1
return count

def _is_sqlite(self, database: str) -> bool:
return "sqlite" in self._databases[database]["ENGINE"]

def _get_db_dir(self, database: str) -> str:
if self._is_sqlite(database):
return self._databases[database]["NAME"]
return ""
186 changes: 131 additions & 55 deletions src/ansys/dynamicreporting/core/serverless/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
ObjectDoesNotExist,
ValidationError,
)
from django.db import DatabaseError
from django.db import DataError
from django.db.models import Model, QuerySet
from django.db.models.base import subclass_exception
from django.db.models.manager import Manager
from django.db.utils import IntegrityError as DBIntegrityError

from ..exceptions import (
ADRException,
IntegrityError,
MultipleObjectsReturnedError,
ObjectDoesNotExistError,
ObjectNotSavedError,
Expand All @@ -42,7 +44,7 @@ def handle_field_errors(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except (FieldError, FieldDoesNotExist, ValidationError, DatabaseError) as e:
except (FieldError, FieldDoesNotExist, ValidationError, DataError) as e:
raise ADRException(extra_detail=f"One or more fields set or accessed are invalid: {e}")

return wrapper
Expand Down Expand Up @@ -97,6 +99,13 @@ def __new__(
parents,
namespace.get("__module__"),
)
add_exception_to_cls(
"IntegrityError",
IntegrityError,
new_cls,
parents,
namespace.get("__module__"),
)
# all classes must be dataclasses
new_cls = dataclass(eq=False, order=False, repr=False)(new_cls)
return new_cls
Expand All @@ -121,7 +130,7 @@ def __getattribute__(cls, name):


class BaseModel(metaclass=BaseMeta):
guid: UUID = field(init=False, compare=False, kw_only=True, default_factory=uuid.uuid1)
guid: UUID = field(compare=False, kw_only=True, default_factory=uuid.uuid1)
tags: str = field(compare=False, kw_only=True, default="")
_saved: bool = field(
init=False, compare=False, default=False
Expand Down Expand Up @@ -209,19 +218,108 @@ def _get_field_names(cls, with_types=False, include_private=False):
fields_.append((f.name, f.type) if with_types else f.name)
return tuple(fields_)

def _get_var_field_names(self, include_private=False):
fields_ = []
for f in vars(self).keys():
if not include_private and f.startswith("_"):
continue
fields_.append(f)
return tuple(fields_)

@classmethod
def _get_all_field_names(cls):
def _get_prop_field_names(cls):
"""Returns a list of all field names from a dataclass, including properties."""
property_fields = []
for name, value in inspect.getmembers(cls):
if isinstance(value, property):
property_fields.append(name)
return tuple(property_fields) + cls._get_field_names()
return tuple(property_fields)

@property
def saved(self):
def saved(self) -> bool:
return self._saved

@property
def _orm_saved(self) -> bool:
return not self._orm_instance._state.adding

@property
def _orm_db(self) -> str:
return self._orm_instance._state.db

@property
def db(self):
return self._orm_db

def as_dict(self):
out_dict = {}
# use a combination of vars and fields
cls_fields = set(self._get_field_names() + self._get_var_field_names())
for field_ in cls_fields:
if field_.startswith("_"):
continue
value = getattr(self, field_, None)
if value is None: # skip and use defaults
continue
out_dict[field_] = value
return out_dict

def _prepare_for_save(self, **kwargs):
self._saved = False

target_db = kwargs.pop("using", "default")
cls_fields = self._get_field_names() + self._get_prop_field_names()
model_fields = self._get_orm_field_names(self._orm_instance)
for field_ in cls_fields:
if field_ not in model_fields:
continue
value = getattr(self, field_, None)
if value is None: # skip and use defaults
continue
if isinstance(value, list):
objs = [o._orm_instance for o in value]
getattr(self._orm_instance, field_).add(*objs)
else:
if isinstance(value, BaseModel): # relations
try:
value = value._orm_instance.__class__.objects.using(target_db).get(
guid=value.guid
)
except ObjectDoesNotExist as e:
raise value.__class__.DoesNotExist(
extra_detail=f"Object with guid '{value.guid}'" f" does not exist: {e}"
)
# for all others
setattr(self._orm_instance, field_, value)

return self

def reinit(self):
self._orm_instance = self.__class__._orm_model_cls()

@handle_field_errors
def save(self, **kwargs):
try:
obj = self._prepare_for_save(**kwargs)
obj._orm_instance.save(**kwargs)
except DBIntegrityError as e:
raise self.__class__.IntegrityError(
extra_detail=f"Save failed for object with guid '{self.guid}': {e}"
)
except Exception as e:
raise e
else:
obj._saved = True

def delete(self, **kwargs):
if not self._saved:
raise self.__class__.NotSaved(
extra_detail=f"Delete failed for object with guid '{self.guid}'."
)
count, _ = self._orm_instance.delete(**kwargs)
self._saved = False
return count

@classmethod
def from_db(cls, orm_instance, parent=None):
cls_fields = dict(cls._get_field_names(with_types=True, include_private=True))
Expand Down Expand Up @@ -284,56 +382,21 @@ def from_db(cls, orm_instance, parent=None):
obj._saved = True
return obj

@handle_field_errors
def save(self, **kwargs):
self._saved = False # reset

cls_fields = self._get_all_field_names()
model_fields = self._get_orm_field_names(self._orm_instance)
for field_ in cls_fields:
if field_ not in model_fields:
continue
value = getattr(self, field_, None)
if value is None:
continue
if isinstance(value, list):
obj_list = []
for obj in value:
obj_list.append(obj._orm_instance)
getattr(self._orm_instance, field_).add(*obj_list)
else:
if isinstance(value, BaseModel): # relations
try:
value = value._orm_instance.__class__.objects.using(
kwargs.get("using", "default")
).get(guid=value.guid)
except ObjectDoesNotExist:
raise value.__class__.DoesNotExist
# for all others
setattr(self._orm_instance, field_, value)

self._orm_instance.save(**kwargs)
self._saved = True

@classmethod
@handle_field_errors
def create(cls, **kwargs):
target_db = kwargs.pop("using", "default")
obj = cls(**kwargs)
obj.save(force_insert=True)
obj.save(force_insert=True, using=target_db)
return obj

def delete(self, **kwargs):
if not self._saved:
raise self.__class__.NotSaved(extra_detail="Delete failed")
count, _ = self._orm_instance.delete(**kwargs)
self._saved = False
return count

@classmethod
@handle_field_errors
def get(cls, **kwargs):
try:
orm_instance = cls._orm_model_cls.objects.get(**kwargs)
orm_instance = cls._orm_model_cls.objects.using(kwargs.pop("using", "default")).get(
**kwargs
)
except ObjectDoesNotExist:
raise cls.DoesNotExist
except MultipleObjectsReturned:
Expand All @@ -343,21 +406,30 @@ def get(cls, **kwargs):

@classmethod
@handle_field_errors
def filter(cls, **kwargs):
qs = cls._orm_model_cls.objects.filter(**kwargs)
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)
def get_or_create(cls, **kwargs):
try:
return cls.get(**kwargs), False
except cls.DoesNotExist:
# Try to create an object using passed params.
try:
return cls.create(**kwargs), True
except cls.IntegrityError:
try:
return cls.get(**kwargs), False
except cls.DoesNotExist:
pass
raise

@classmethod
@handle_field_errors
def bulk_create(cls, **kwargs):
objs = cls._orm_model_cls.objects.bulk_create(**kwargs)
qs = cls._orm_model_cls.objects.filter(pk__in=[obj.pk for obj in objs])
def filter(cls, **kwargs):
qs = cls._orm_model_cls.objects.using(kwargs.pop("using", "default")).filter(**kwargs)
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)

@classmethod
@handle_field_errors
def find(cls, query="", reverse=False, sort_tag="date"):
qs = cls._orm_model_cls.find(query=query, reverse=reverse, sort_tag=sort_tag)
def find(cls, query="", **kwargs):
qs = cls._orm_model_cls.find(query=query, **kwargs)
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)

def get_tags(self):
Expand Down Expand Up @@ -426,9 +498,13 @@ def saved(self):
return self._saved

def delete(self):
count = 0
for obj in self._obj_set:
obj.delete()
count += 1
self._orm_queryset.delete()
self._obj_set = []
self._saved = False
count, _ = self._orm_queryset.delete()
return count

def values_list(self, *fields, flat=False):
Expand Down
Loading
Loading