Skip to content

Commit

Permalink
Merge pull request #411 from hackforla/refactor-data-access
Browse files Browse the repository at this point in the history
refactor DataAccessLayer class
  • Loading branch information
tylerthome authored Jun 29, 2022
2 parents 17c8e50 + 9d8e9f8 commit 2bc5771
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 16 deletions.
2 changes: 2 additions & 0 deletions api/openapi_server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from os import environ as env

from openapi_server import encoder
from openapi_server.models.database import DataAccessLayer
from openapi_server.exceptions import AuthError, handle_auth_error
from dotenv import load_dotenv, find_dotenv

Expand All @@ -13,6 +14,7 @@
load_dotenv(ENV_FILE)
SECRET_KEY=env.get('SECRET_KEY')

DataAccessLayer.db_init()

def main():
app = connexion.App(__name__, specification_dir='./_spec/')
Expand Down
13 changes: 6 additions & 7 deletions api/openapi_server/controllers/service_provider_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from openapi_server.models import database as db
from sqlalchemy.orm import Session

dal = db.DataAccessLayer()
dal.db_init()
db_engine = db.DataAccessLayer.get_engine()

def create_service_provider(): # noqa: E501
"""Create a housing program service provider
Expand All @@ -27,7 +26,7 @@ def create_service_provider(): # noqa: E501
connexion.request.get_json()).to_dict()
except ValueError:
return traceback.format_exc(ValueError), 400
with Session(dal.engine) as session:
with Session(db_engine) as session:
row = db.HousingProgramServiceProvider(
provider_name=provider["provider_name"]
)
Expand All @@ -51,7 +50,7 @@ def delete_service_provider(provider_id): # noqa: E501
:rtype: None
"""
with Session(dal.engine) as session:
with Session(db_engine) as session:
query = session.query(
db.HousingProgramServiceProvider).filter(
db.HousingProgramServiceProvider.id == provider_id)
Expand All @@ -71,7 +70,7 @@ def get_service_provider_by_id(provider_id): # noqa: E501
:rtype: ServiceProviderWithId
"""
with Session(dal.engine) as session:
with Session(db_engine) as session:
row = session.get(
db.HousingProgramServiceProvider, provider_id)
if row != None:
Expand All @@ -93,7 +92,7 @@ def get_service_providers(): # noqa: E501
:rtype: List[ServiceProviderWithId]
"""
resp = []
with Session(dal.engine) as session:
with Session(db_engine) as session:
table = session.query(db.HousingProgramServiceProvider).all()
for row in table:
provider = ServiceProvider(
Expand Down Expand Up @@ -122,7 +121,7 @@ def update_service_provider(provider_id): # noqa: E501
connexion.request.get_json()).to_dict()
except ValueError:
return traceback.format_exc(ValueError), 400
with Session(dal.engine) as session:
with Session(db_engine) as session:
query = session.query(
db.HousingProgramServiceProvider).filter(
db.HousingProgramServiceProvider.id == provider_id)
Expand Down
24 changes: 15 additions & 9 deletions api/openapi_server/models/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,22 @@ class ProgramCaseStatusLog(Base):
src_status = Column(Integer, ForeignKey('case_status.id'), nullable=False)
dest_status = Column(Integer, ForeignKey('case_status.id'), nullable=False)



class DataAccessLayer:
connection = None
engine = None
_engine = None

# temporary local sqlite DB, replace with conn str for postgres container port for real e2e
conn_string = "sqlite:///./homeuniteus.db"
_conn_string = "sqlite:///./homeuniteus.db"

def db_init(self, conn_string=None):
self.engine = create_engine(conn_string or self.conn_string, echo=True, future=True)
Base.metadata.create_all(bind=self.engine)
self.connection = self.engine.connect()
@classmethod
def db_init(cls, conn_string=None):
Base.metadata.create_all(bind=cls.get_engine(conn_string))

@classmethod
def connect(cls):
return cls.get_engine().connect()

@classmethod
def get_engine(cls, conn_string=None):
if cls._engine == None:
cls._engine = create_engine(conn_string or cls._conn_string, echo=True, future=True)
return cls._engine

0 comments on commit 2bc5771

Please sign in to comment.