diff --git a/tests/conftest.py b/tests/conftest.py index ec975866..370abd71 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from fastapi.testclient import TestClient from moto import mock_s3 from sqlalchemy import create_engine +from sqlalchemy.engine import Connection from sqlalchemy.orm import sessionmaker from sqlalchemy_utils import create_database, database_exists, drop_database @@ -133,52 +134,42 @@ def test_db(scope="function"): drop_database(test_db_url) -@pytest.fixture -def data_db(scope="function"): - """ - Create a fresh test database for each test. - - This will populate the db using the alembic migrations. - Therefore it is slower but contains data. - - Note: use with `data_client` - - """ +@pytest.fixture(scope="session") +def data_db_connection() -> t.Generator[Connection, None, None]: test_db_url = get_test_db_url() - # Create the test database if database_exists(test_db_url): drop_database(test_db_url) create_database(test_db_url) - # Save DATABASE_URL - saved = os.environ["DATABASE_URL"] + + saved_db_url = os.environ["DATABASE_URL"] os.environ["DATABASE_URL"] = test_db_url - connection = None - test_session = None - try: - test_engine = create_engine(test_db_url) - connection = test_engine.connect() - run_migrations(test_engine) + test_engine = create_engine(test_db_url) - test_session_maker = sessionmaker( - autocommit=False, - autoflush=False, - bind=test_engine, - ) - test_session = test_session_maker() + run_migrations(test_engine) + connection = test_engine.connect() - # Run the tests - yield test_session - finally: - # restore DATABASE_URL - os.environ["DATABASE_URL"] = saved - if test_session is not None: - test_session.close() - if connection is not None: - connection.close() - # Drop the test database - drop_database(test_db_url) + yield connection + connection.close() + + os.environ["DATABASE_URL"] = saved_db_url + drop_database(test_db_url) + + +@pytest.fixture(scope="function") +def data_db(data_db_connection): + transaction = data_db_connection.begin() + + SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=data_db_connection + ) + session = SessionLocal() + + yield session + + session.close() + transaction.rollback() @pytest.fixture diff --git a/tests/routes/test_document_families.py b/tests/routes/test_document_families.py index 11a2cf63..98fd8671 100644 --- a/tests/routes/test_document_families.py +++ b/tests/routes/test_document_families.py @@ -53,14 +53,7 @@ def test_physical_doc_languages_not_visible( data_db: Session, ): setup_with_two_docs(data_db) - # setup_with_multiple_docs( - # data_db, doc_data=TWO_DFC_ROW_ONE_LANGUAGE, event_data=TWO_EVENT_ROWS - # ) - data_db.execute( - update(PhysicalDocumentLanguage) - .where(PhysicalDocumentLanguage.document_id == 1) - .values(visible=False) - ) + data_db.execute(update(PhysicalDocumentLanguage).values(visible=False)) response = data_client.get( "/api/v1/documents/DocSlug1",