Skip to content

Commit

Permalink
Unify DAG creation/database cleaning fixtures for testing (#3361)
Browse files Browse the repository at this point in the history
* implement global clean_db function

* remove unused packages

* format lint

* hardcode TEST_POOL and DAG_PREFIX

* add comment and TODO for TEST_POOL and DAG_PREFIX constant

* rename fixture for format standards

* rename get_test_dag_id to sample_dag_id_fixture
  • Loading branch information
ngken0995 authored Dec 14, 2023
1 parent b7337c6 commit 80a75e6
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 71 deletions.
35 changes: 35 additions & 0 deletions catalog/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest
from airflow.models import DagRun, Pool, TaskInstance
from airflow.utils.session import create_session


def pytest_addoption(parser):
Expand All @@ -24,3 +26,36 @@ def pytest_addoption(parser):
# Use this decorator on tests which are expected to take a long time and would best be
# run on CI only
mark_extended = pytest.mark.skipif("not config.getoption('extended')")


def _normalize_test_module_name(request) -> str:
    """Return the requesting test module's dotted name with the dots replaced.

    Args:
        request: pytest ``request`` object; only ``request.module.__name__``
            is read.

    Returns:
        The module name with every ``.`` replaced by two underscores.
    """
    # This is the *module* name, not the individual test name, so every test
    # in one file gets the same normalized value.
    module_name = request.module.__name__
    return module_name.replace(".", "__")


@pytest.fixture
def sample_dag_id_fixture(request):
    """Return a DAG id derived from the requesting test module's name.

    The value is the normalized module name followed by ``_dag``.
    """
    module_slug = _normalize_test_module_name(request)
    return module_slug + "_dag"


@pytest.fixture
def sample_pool_fixture(request):
    """Return a pool name derived from the requesting test module's name.

    The value is the normalized module name followed by ``_pool``.
    """
    module_slug = _normalize_test_module_name(request)
    return module_slug + "_pool"


def _delete_test_records(dag_id_prefix, pool_prefix):
    """Delete DagRun/TaskInstance rows and Pool rows matching the prefixes.

    Args:
        dag_id_prefix: DagRun and TaskInstance rows whose ``dag_id`` starts
            with this string are deleted.
        pool_prefix: Pool rows whose ``pool`` starts with this string are
            deleted.
    """
    with create_session() as session:
        # synchronize_session='fetch' required here to refresh models
        # https://stackoverflow.com/a/51222378 CC BY-SA 4.0
        #
        # autoescape=True: ``startswith`` renders as SQL LIKE, where ``_`` is
        # a single-character wildcard. The prefixes are built from module
        # names joined by underscores, so without escaping they could match
        # rows whose names merely have the same shape.
        session.query(DagRun).filter(
            DagRun.dag_id.startswith(dag_id_prefix, autoescape=True)
        ).delete(synchronize_session="fetch")
        session.query(TaskInstance).filter(
            TaskInstance.dag_id.startswith(dag_id_prefix, autoescape=True)
        ).delete(synchronize_session="fetch")
        session.query(Pool).filter(
            Pool.pool.startswith(pool_prefix, autoescape=True)
        ).delete(synchronize_session="fetch")


@pytest.fixture
def clean_db(sample_dag_id_fixture, sample_pool_fixture):
    """Remove this test module's DagRuns, TaskInstances and Pools.

    Cleanup runs both before and after the requesting test: the first pass
    clears anything left by an earlier failed or interrupted run, and the
    second pass keeps the test from leaving rows behind.
    """
    _delete_test_records(sample_dag_id_fixture, sample_pool_fixture)
    yield
    _delete_test_records(sample_dag_id_fixture, sample_pool_fixture)
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

import pytest
from airflow.exceptions import AirflowException
from airflow.models import DagBag, DagRun, Pool, TaskInstance
from airflow.models import DagBag, Pool
from airflow.models.dag import DAG
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
Expand All @@ -15,23 +14,14 @@

DEFAULT_DATE = datetime(2022, 1, 1)
TEST_TASK_ID = "wait_task"
TEST_POOL = "single_run_external_dags_sensor_test_pool"
DEV_NULL = "/dev/null"
DAG_PREFIX = "sreds" # single_run_external_dags_sensor


@pytest.fixture(autouse=True)
def clean_db():
    """Delete DagRuns/TaskInstances prefixed with DAG_PREFIX and the test pool.

    Runs automatically before every test in this module.
    """
    with create_session() as session:
        # synchronize_session='fetch' required here to refresh models
        # https://stackoverflow.com/a/51222378 CC BY-SA 4.0
        session.query(DagRun).filter(DagRun.dag_id.startswith(DAG_PREFIX)).delete(
            synchronize_session="fetch"
        )
        session.query(TaskInstance).filter(
            TaskInstance.dag_id.startswith(DAG_PREFIX)
        ).delete(synchronize_session="fetch")
        # Filter on the pool-name column. A bare ``id`` is the Python builtin,
        # so ``id == TEST_POOL`` is always False and would match no rows.
        session.query(Pool).filter(Pool.pool == TEST_POOL).delete()
# unittest.TestCase only allows autouse fixtures, which can't retrieve the fixtures declared in conftest.py.
# TODO: The TEST_POOL/DAG_PREFIX constants can be removed after the unittest.TestCase classes are converted to pytest.
TEST_POOL = (
"catalog__tests__dags__common__sensors__test_single_run_external_dags_sensor_pool"
)
DAG_PREFIX = "catalog__tests__dags__common__sensors__test_single_run_external_dags_sensor_dag" # single_run_external_dags_sensor


def run_sensor(sensor):
Expand Down Expand Up @@ -75,6 +65,7 @@ def create_dagrun(dag, dag_state):
)


@pytest.mark.usefixtures("clean_db")
# This appears to be coming from Airflow internals during testing as a result of
# loading the example DAGs:
# /opt/airflow/.local/lib/python3.10/site-packages/airflow/example_dags/example_subdag_operator.py:43: RemovedInAirflow3Warning # noqa: E501
Expand Down
28 changes: 11 additions & 17 deletions catalog/tests/dags/common/sensors/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from datetime import timedelta

import pytest
from airflow.models import DagRun
from airflow.models.dag import DAG
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
Expand All @@ -15,14 +12,8 @@
TEST_DAG = DAG(TEST_DAG_ID, default_args={"owner": "airflow"})


@pytest.fixture(autouse=True)
def clean_db():
    """Delete every DagRun whose dag_id equals ``TEST_DAG_ID`` before each test."""
    with create_session() as db:
        stale_runs = db.query(DagRun).filter(DagRun.dag_id == TEST_DAG_ID)
        stale_runs.delete()


def _create_dagrun(start_date, conf={}):
return TEST_DAG.create_dagrun(
def _create_dagrun(start_date, sample_dag_id_fixture, conf={}):
return DAG(sample_dag_id_fixture, default_args={"owner": "airflow"}).create_dagrun(
start_date=start_date,
execution_date=start_date,
data_interval=(start_date, start_date),
Expand All @@ -32,14 +23,17 @@ def _create_dagrun(start_date, conf={}):
)


def test_get_most_recent_dag_run_returns_most_recent_execution_date():
def test_get_most_recent_dag_run_returns_most_recent_execution_date(
sample_dag_id_fixture, clean_db
):
most_recent = datetime(2023, 5, 10)
for i in range(3):
_create_dagrun(most_recent - timedelta(days=i))

assert get_most_recent_dag_run(TEST_DAG_ID) == most_recent
_create_dagrun(most_recent - timedelta(days=i), sample_dag_id_fixture)
assert get_most_recent_dag_run(sample_dag_id_fixture) == most_recent


def test_get_most_recent_dag_run_returns_empty_list_when_no_runs():
def test_get_most_recent_dag_run_returns_empty_list_when_no_runs(
sample_dag_id_fixture, clean_db
):
# Relies on ``clean_db`` cleaning up DagRuns from other tests
assert get_most_recent_dag_run(TEST_DAG_ID) == []
assert get_most_recent_dag_run(sample_dag_id_fixture) == []
23 changes: 4 additions & 19 deletions catalog/tests/dags/common/test_ingestion_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import pytest
import requests
from airflow.exceptions import AirflowSkipException
from airflow.models import DagRun, TaskInstance
from airflow.models.dag import DAG
from airflow.utils.session import create_session
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
Expand All @@ -15,30 +13,17 @@


TEST_START_DATE = datetime(2022, 2, 1, 0, 0, 0)
TEST_DAG_ID = "api_healthcheck_test_dag"


@pytest.fixture(autouse=True)
def clean_db():
    """Delete DagRuns and TaskInstances whose dag_id starts with ``TEST_DAG_ID``.

    Runs automatically before every test in this module.
    """
    with create_session() as db:
        # synchronize_session='fetch' required here to refresh models
        # https://stackoverflow.com/a/51222378 CC BY-SA 4.0
        for model in (DagRun, TaskInstance):
            matching = db.query(model).filter(model.dag_id.startswith(TEST_DAG_ID))
            matching.delete(synchronize_session="fetch")


@pytest.fixture()
def index_readiness_dag():
def index_readiness_dag(sample_dag_id_fixture, clean_db):
# Create a DAG that just has an index_readiness_check task
with DAG(dag_id=TEST_DAG_ID, schedule=None, start_date=TEST_START_DATE) as dag:
with DAG(
dag_id=sample_dag_id_fixture, schedule=None, start_date=TEST_START_DATE
) as dag:
ingestion_server.index_readiness_check(
media_type="image", index_suffix="my_test_suffix", timeout=timedelta(days=1)
)

return dag


Expand Down
18 changes: 0 additions & 18 deletions catalog/tests/dags/providers/test_provider_dag_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from airflow import DAG
from airflow.exceptions import AirflowSkipException, BackfillUnfinished
from airflow.executors.debug_executor import DebugExecutor
from airflow.models import DagRun, TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.utils.session import create_session
from pendulum import now

from catalog.tests.conftest import mark_extended
Expand All @@ -21,22 +19,6 @@
from providers.provider_workflows import ProviderWorkflow


DAG_ID = "test_provider_dag_factory"


def _clean_dag_from_db():
    """Delete the DagRun and TaskInstance rows whose dag_id equals ``DAG_ID``."""
    with create_session() as db:
        # DagRun rows are removed first, then TaskInstance rows.
        for model in (DagRun, TaskInstance):
            db.query(model).filter(model.dag_id == DAG_ID).delete()


@pytest.fixture()
def clean_db():
    """Run ``_clean_dag_from_db`` before and again after the requesting test."""
    # Setup: clear rows left over from earlier runs.
    _clean_dag_from_db()
    yield
    # Teardown: clear rows created by the test that just ran.
    _clean_dag_from_db()


@mark_extended
@pytest.mark.parametrize(
"side_effect",
Expand Down

0 comments on commit 80a75e6

Please sign in to comment.