Rewrite how DAG to dataset / dataset alias are stored #42055

Merged · 1 commit · Sep 6, 2024
airflow/models/dag.py: 88 changes (50 additions, 38 deletions)
@@ -3236,8 +3236,6 @@ def bulk_write_to_db(
         if not dags:
             return

-        from airflow.models.dataset import DagScheduleDatasetAliasReference
-
         log.info("Sync %s DAGs", len(dags))
         dag_by_ids = {dag.dag_id: dag for dag in dags}

@@ -3344,18 +3342,19 @@ def bulk_write_to_db(

         from airflow.datasets import Dataset
         from airflow.models.dataset import (
+            DagScheduleDatasetAliasReference,
             DagScheduleDatasetReference,
             DatasetModel,
             TaskOutletDatasetReference,
         )

-        dag_references: dict[str, set[Dataset | DatasetAlias]] = defaultdict(set)
+        dag_references: dict[str, set[tuple[Literal["dataset", "dataset-alias"], str]]] = defaultdict(set)
         outlet_references = defaultdict(set)
         # We can't use a set here as we want to preserve order
-        outlet_datasets: dict[DatasetModel, None] = {}
-        input_datasets: dict[DatasetModel, None] = {}
+        outlet_dataset_models: dict[DatasetModel, None] = {}
+        input_dataset_models: dict[DatasetModel, None] = {}
         outlet_dataset_alias_models: set[DatasetAliasModel] = set()
-        input_dataset_aliases: set[DatasetAliasModel] = set()
+        input_dataset_alias_models: set[DatasetAliasModel] = set()

         # here we go through dags and tasks to check for dataset references
         # if there are now None and previously there were some, we delete them
@@ -3371,12 +3370,12 @@ def bulk_write_to_db(
                     curr_orm_dag.schedule_dataset_alias_references = []
             else:
                 for _, dataset in dataset_condition.iter_datasets():
-                    dag_references[dag.dag_id].add(Dataset(uri=dataset.uri))
-                    input_datasets[DatasetModel.from_public(dataset)] = None
+                    dag_references[dag.dag_id].add(("dataset", dataset.uri))
+                    input_dataset_models[DatasetModel.from_public(dataset)] = None

                 for dataset_alias in dataset_condition.iter_dataset_aliases():
-                    dag_references[dag.dag_id].add(dataset_alias)
-                    input_dataset_aliases.add(DatasetAliasModel.from_public(dataset_alias))
+                    dag_references[dag.dag_id].add(("dataset-alias", dataset_alias.name))
+                    input_dataset_alias_models.add(DatasetAliasModel.from_public(dataset_alias))

             curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
             for task in dag.tasks:
@@ -3399,63 +3398,70 @@ def bulk_write_to_db(
                             curr_outlet_references.remove(ref)

                 for d in dataset_outlets:
+                    outlet_dataset_models[DatasetModel.from_public(d)] = None
                     outlet_references[(task.dag_id, task.task_id)].add(d.uri)
-                    outlet_datasets[DatasetModel.from_public(d)] = None

                 for d_a in dataset_alias_outlets:
                     outlet_dataset_alias_models.add(DatasetAliasModel.from_public(d_a))

-        all_datasets = outlet_datasets
-        all_datasets.update(input_datasets)
+        all_dataset_models = outlet_dataset_models
+        all_dataset_models.update(input_dataset_models)

         # store datasets
-        stored_datasets: dict[str, DatasetModel] = {}
-        new_datasets: list[DatasetModel] = []
-        for dataset in all_datasets:
-            stored_dataset = session.scalar(
+        stored_dataset_models: dict[str, DatasetModel] = {}
+        new_dataset_models: list[DatasetModel] = []
+        for dataset in all_dataset_models:
+            stored_dataset_model = session.scalar(
                 select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1)
             )
-            if stored_dataset:
+            if stored_dataset_model:
                 # Some datasets may have been previously unreferenced, and therefore orphaned by the
                 # scheduler. But if we're here, then we have found that dataset again in our DAGs, which
                 # means that it is no longer an orphan, so set is_orphaned to False.
-                stored_dataset.is_orphaned = expression.false()
-                stored_datasets[stored_dataset.uri] = stored_dataset
+                stored_dataset_model.is_orphaned = expression.false()
+                stored_dataset_models[stored_dataset_model.uri] = stored_dataset_model
             else:
-                new_datasets.append(dataset)
-        dataset_manager.create_datasets(dataset_models=new_datasets, session=session)
-        stored_datasets.update({dataset.uri: dataset for dataset in new_datasets})
+                new_dataset_models.append(dataset)
+        dataset_manager.create_datasets(dataset_models=new_dataset_models, session=session)
+        stored_dataset_models.update(
+            {dataset_model.uri: dataset_model for dataset_model in new_dataset_models}
+        )

-        del new_datasets
-        del all_datasets
+        del new_dataset_models
+        del all_dataset_models

         # store dataset aliases
-        all_datasets_alias_models = input_dataset_aliases | outlet_dataset_alias_models
-        stored_dataset_aliases: dict[str, DatasetAliasModel] = {}
+        all_datasets_alias_models = input_dataset_alias_models | outlet_dataset_alias_models
+        stored_dataset_alias_models: dict[str, DatasetAliasModel] = {}
         new_dataset_alias_models: set[DatasetAliasModel] = set()
         if all_datasets_alias_models:
-            all_dataset_alias_names = {dataset_alias.name for dataset_alias in all_datasets_alias_models}
+            all_dataset_alias_names = {
+                dataset_alias_model.name for dataset_alias_model in all_datasets_alias_models
+            }

-            stored_dataset_aliases = {
+            stored_dataset_alias_models = {
                 dsa_m.name: dsa_m
                 for dsa_m in session.scalars(
                     select(DatasetAliasModel).where(DatasetAliasModel.name.in_(all_dataset_alias_names))
                 ).fetchall()
             }

-            if stored_dataset_aliases:
+            if stored_dataset_alias_models:
                 new_dataset_alias_models = {
                     dataset_alias_model
                     for dataset_alias_model in all_datasets_alias_models
-                    if dataset_alias_model.name not in stored_dataset_aliases.keys()
+                    if dataset_alias_model.name not in stored_dataset_alias_models.keys()
                 }
             else:
                 new_dataset_alias_models = all_datasets_alias_models

             session.add_all(new_dataset_alias_models)
             session.flush()
-            stored_dataset_aliases.update(
-                {dataset_alias.name: dataset_alias for dataset_alias in new_dataset_alias_models}
+            stored_dataset_alias_models.update(
+                {
+                    dataset_alias_model.name: dataset_alias_model
+                    for dataset_alias_model in new_dataset_alias_models
+                }
             )

             del new_dataset_alias_models
@@ -3464,14 +3470,18 @@ def bulk_write_to_db(
         # reconcile dag-schedule-on-dataset and dag-schedule-on-dataset-alias references
         for dag_id, base_dataset_list in dag_references.items():
             dag_refs_needed = {
-                DagScheduleDatasetReference(dataset_id=stored_datasets[base_dataset.uri].id, dag_id=dag_id)
-                if isinstance(base_dataset, Dataset)
+                DagScheduleDatasetReference(
+                    dataset_id=stored_dataset_models[base_dataset_identifier].id, dag_id=dag_id
+                )
+                if base_dataset_type == "dataset"
                 else DagScheduleDatasetAliasReference(
-                    alias_id=stored_dataset_aliases[base_dataset.name].id, dag_id=dag_id
+                    alias_id=stored_dataset_alias_models[base_dataset_identifier].id, dag_id=dag_id
                 )
-                for base_dataset in base_dataset_list
+                for base_dataset_type, base_dataset_identifier in base_dataset_list
             }

+            # if isinstance(base_dataset, Dataset)
+
             dag_refs_stored = (
                 set(existing_dags.get(dag_id).schedule_dataset_references)  # type: ignore
                 | set(existing_dags.get(dag_id).schedule_dataset_alias_references)  # type: ignore
@@ -3491,7 +3501,9 @@ def bulk_write_to_db(
         # reconcile task-outlet-dataset references
         for (dag_id, task_id), uri_list in outlet_references.items():
             task_refs_needed = {
-                TaskOutletDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id, task_id=task_id)
+                TaskOutletDatasetReference(
+                    dataset_id=stored_dataset_models[uri].id, dag_id=dag_id, task_id=task_id
+                )
                 for uri in uri_list
             }
             task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)]
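
For readers following the keying change, the sketch below is a minimal, self-contained illustration (not Airflow code) of the idea in the diff: DAG scheduling references are collected as hashable ("dataset", uri) / ("dataset-alias", name) tuples and later resolved against lookups of stored rows. The StoredDataset, StoredAlias, DatasetRef, and AliasRef classes and the example identifiers are hypothetical stand-ins for the ORM models and real data.

# Illustrative sketch only; the classes below are simplified stand-ins,
# not the Airflow ORM models used in the PR.
from __future__ import annotations

from collections import defaultdict
from typing import Literal, NamedTuple


class StoredDataset(NamedTuple):
    id: int
    uri: str


class StoredAlias(NamedTuple):
    id: int
    name: str


class DatasetRef(NamedTuple):
    dataset_id: int
    dag_id: str


class AliasRef(NamedTuple):
    alias_id: int
    dag_id: str


# References are keyed by (type, identifier) tuples instead of Dataset /
# DatasetAlias objects, so a plain set can deduplicate them.
dag_references: dict[str, set[tuple[Literal["dataset", "dataset-alias"], str]]] = defaultdict(set)
dag_references["example_dag"].add(("dataset", "s3://bucket/a.csv"))
dag_references["example_dag"].add(("dataset-alias", "nightly-export"))
dag_references["example_dag"].add(("dataset", "s3://bucket/a.csv"))  # duplicate, dropped by the set

# Identifier-to-row lookups, analogous to stored_dataset_models / stored_dataset_alias_models.
stored_datasets = {"s3://bucket/a.csv": StoredDataset(id=1, uri="s3://bucket/a.csv")}
stored_aliases = {"nightly-export": StoredAlias(id=7, name="nightly-export")}

# Reconciliation dispatches on the tuple's first element rather than on isinstance checks.
for dag_id, refs in dag_references.items():
    needed = {
        DatasetRef(dataset_id=stored_datasets[identifier].id, dag_id=dag_id)
        if ref_type == "dataset"
        else AliasRef(alias_id=stored_aliases[identifier].id, dag_id=dag_id)
        for ref_type, identifier in refs
    }
    print(dag_id, sorted(needed, key=repr))

Keying on plain tuples keeps the collected references hashable and lets the reconciliation step branch on a string tag, which is the shape of the change applied to dag_references in the diff above.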