Add HookLineageCollector that registers and holds lineage sent from hooks
during task execution.

Add HookLineageReader, whose presence determines whether HookLineageCollector
is enabled to process lineage sent from hooks.

Add Dataset factories to make sure Datasets registered with
HookLineageCollector are AIP-60 compliant.
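
For illustration, a hook could report lineage through the new collector
roughly like this (a sketch based on the API added here; the S3 URI and the
surrounding hook method are hypothetical):

from airflow.lineage.hook import get_hook_lineage_collector

# Inside a hook method, after reading the source object ("self" is the hook):
get_hook_lineage_collector().add_input_dataset(
    {"uri": "s3://my-bucket/my-key"}, hook=self
)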

Signed-off-by: Jakub Dardzinski <[email protected]>
JDarDagran committed Jun 19, 2024
1 parent d048122 commit 1293a07
Showing 5 changed files with 302 additions and 5 deletions.
5 changes: 5 additions & 0 deletions airflow/datasets/__init__.py
@@ -41,6 +41,11 @@ def normalize_noop(parts: SplitResult) -> SplitResult:
return parts


def create_dataset(uri: str) -> Dataset:
"""Create a dataset object from a dataset URI."""
return Dataset(uri=uri)


def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None:
if scheme == "file":
return normalize_noop
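The new create_dataset function is the default dataset factory: it simply
wraps a URI in a Dataset, with no provider-specific handling. A minimal
sketch of its use:

from airflow.datasets import create_dataset

# Equivalent to Dataset(uri="s3://my-bucket/my-key")
dataset = create_dataset("s3://my-bucket/my-key")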
140 changes: 140 additions & 0 deletions airflow/lineage/hook.py
@@ -0,0 +1,140 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Union

import attr

from airflow.datasets import Dataset, create_dataset
from airflow.hooks.base import BaseHook
from airflow.io.store import ObjectStore
from airflow.providers_manager import ProvidersManager
from airflow.utils.log.logging_mixin import LoggingMixin

# The context from which lineage was sent: either a hook or an object store.
LineageContext = Union[BaseHook, ObjectStore]

_hook_lineage_collector: HookLineageCollector | None = None


@attr.define
class HookLineage:
"""Holds lineage collected by HookLineageCollector."""

inputs: list[tuple[Dataset, LineageContext]] = attr.ib(factory=list)
outputs: list[tuple[Dataset, LineageContext]] = attr.ib(factory=list)


class HookLineageCollector(LoggingMixin):
"""
HookLineageCollector is a base class for collecting hook lineage information.
It is used to collect the input and output datasets of a hook execution.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.inputs: list[tuple[Dataset, LineageContext]] = []
self.outputs: list[tuple[Dataset, LineageContext]] = []

@staticmethod
def create_dataset(dataset_kwargs: dict) -> Dataset:
"""Create a Dataset instance from the given dataset kwargs."""
if "uri" in dataset_kwargs:
# A URI was provided; use the default factory to create the dataset.
return create_dataset(dataset_kwargs["uri"])

scheme: str | None = dataset_kwargs.pop("scheme", None)
if not scheme:
raise ValueError(
"Missing required parameter: either 'uri' or 'scheme' must be provided to create a Dataset."
)

dataset_factory = ProvidersManager().dataset_factories.get(scheme)
if not dataset_factory:
raise ValueError(
f"Unsupported scheme: '{scheme}'. Please provide a valid URI to create a Dataset."
)

return dataset_factory(**dataset_kwargs)

def add_input_dataset(self, dataset_kwargs: dict, hook: LineageContext):
"""Add the input dataset and its corresponding hook execution context to the collector."""
dataset = self.create_dataset(dataset_kwargs)
self.inputs.append((dataset, hook))

def add_output_dataset(self, dataset_kwargs: dict, hook: LineageContext):
"""Add the output dataset and its corresponding hook execution context to the collector."""
dataset = self.create_dataset(dataset_kwargs)
self.outputs.append((dataset, hook))

@property
def collected_datasets(self) -> HookLineage:
"""Get the collected hook lineage information."""
return HookLineage(self.inputs, self.outputs)

@property
def has_collected(self) -> bool:
"""Check if any datasets have been collected."""
return len(self.inputs) != 0 or len(self.outputs) != 0


class NoOpCollector(HookLineageCollector):
"""
NoOpCollector is a hook lineage collector that does nothing.
It is used when you want to disable lineage collection.
"""

def add_input_dataset(self, *_):
pass

def add_output_dataset(self, *_):
pass

@property
def collected_datasets(self) -> HookLineage:
self.log.warning("NoOpCollector does not collect lineage; register a hook lineage reader to enable collection.")
return HookLineage([], [])


class HookLineageReader(LoggingMixin):
"""Class used to retrieve the hook lineage information collected by HookLineageCollector."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.lineage_collector = get_hook_lineage_collector()

def retrieve_hook_lineage(self) -> HookLineage:
"""Retrieve hook lineage from HookLineageCollector."""
return self.lineage_collector.collected_datasets


def get_hook_lineage_collector() -> HookLineageCollector:
"""Get singleton lineage collector."""
global _hook_lineage_collector
if not _hook_lineage_collector:
# Use a real collector only when at least one hook lineage reader is registered;
# otherwise fall back to the no-op implementation.
if ProvidersManager().hook_lineage_readers:
_hook_lineage_collector = HookLineageCollector()
else:
_hook_lineage_collector = NoOpCollector()
return _hook_lineage_collector
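
HookLineageCollector.create_dataset dispatches on the supplied kwargs: a
"uri" key selects the default factory, while a "scheme" key selects a
provider-registered factory. A sketch of both paths (the "myscheme" factory
is hypothetical and only exists if some provider registers it):

from airflow.lineage.hook import HookLineageCollector

collector = HookLineageCollector()

# Path 1: a URI is given, so the default create_dataset factory is used.
ds1 = collector.create_dataset({"uri": "s3://bucket/file"})

# Path 2: no URI; "myscheme" is looked up in ProvidersManager().dataset_factories
# and the remaining kwargs are passed to that factory (ValueError if no factory
# is registered for the scheme).
ds2 = collector.create_dataset({"scheme": "myscheme", "arg1": "value_1"})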
11 changes: 11 additions & 0 deletions airflow/provider.yaml.schema.json
@@ -212,10 +212,21 @@
"handler": {
"type": ["string", "null"],
"description": "Normalization function for specified URI schemes. Import path to a callable taking and returning a SplitResult. 'null' specifies a no-op."
},
"factory": {
"type": ["string", "null"],
"description": "Dataset factory for specified URI. Creates AIP-60 compliant Dataset."
}
}
}
},
"hook-lineage-readers": {
"type": "array",
"description": "Hook lineage readers",
"items": {
"type": "string"
}
},
"transfers": {
"type": "array",
"items": {
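For reference, provider metadata matching the extended schema could look like
the following (shown as the Python dict that ProvidersManager reads via
provider.data; the module paths are illustrative):

provider_data = {
    "dataset-uris": [
        {
            "schemes": ["myscheme"],
            "handler": "my_provider.datasets.sanitize_uri",
            "factory": "my_provider.datasets.create_dataset",
        }
    ],
    "hook-lineage-readers": ["my_provider.lineage.MyHookLineageReader"],
}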
43 changes: 38 additions & 5 deletions airflow/providers_manager.py
Expand Up @@ -91,6 +91,7 @@ def ensure_prefix(field):
if TYPE_CHECKING:
from urllib.parse import SplitResult

from airflow.datasets import Dataset
from airflow.decorators.base import TaskDecorator
from airflow.hooks.base import BaseHook
from airflow.typing_compat import Literal
@@ -426,6 +427,8 @@ def __init__(self):
self._hooks_dict: dict[str, HookInfo] = {}
self._fs_set: set[str] = set()
self._dataset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {}
self._dataset_factories: dict[str, Callable[..., Dataset]] = {}
self._hook_lineage_readers: set[str] = set()
self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache() # type: ignore[assignment]
# keeps mapping between connection_types and hook class, package they come from
self._hook_provider_dict: dict[str, HookClassProvider] = {}
@@ -526,8 +529,15 @@ def initialize_providers_filesystems(self):
def initialize_providers_dataset_uri_handlers(self):
"""Lazy initialization of provider dataset URI handlers."""
self.initialize_providers_list()
-self._discover_dataset_uri_handlers()
+self._discover_dataset_uri_handlers_and_factories()

@provider_info_cache("hook_lineage_readers")
def initialize_providers_hook_lineage_readers(self):
"""Lazy initialization of providers hook lineage readers."""
self.initialize_providers_list()
self._discover_hook_lineage_readers()

@provider_info_cache("hook_lineage_writers")
@provider_info_cache("taskflow_decorators")
def initialize_providers_taskflow_decorator(self):
"""Lazy initialization of providers hooks."""
@@ -564,7 +574,7 @@ def initialize_providers_notifications(self):
self.initialize_providers_list()
self._discover_notifications()

@provider_info_cache("auth_managers")
@provider_info_cache(cache_name="auth_managers")
def initialize_providers_auth_managers(self):
"""Lazy initialization of providers notifications information."""
self.initialize_providers_list()
@@ -878,21 +888,34 @@ def _discover_filesystems(self) -> None:
self._fs_set.add(fs_module_name)
self._fs_set = set(sorted(self._fs_set))

-def _discover_dataset_uri_handlers(self) -> None:
-from airflow.datasets import normalize_noop
+def _discover_dataset_uri_handlers_and_factories(self) -> None:
+from airflow.datasets import create_dataset, normalize_noop

for provider_package, provider in self._provider_dict.items():
for handler_info in provider.data.get("dataset-uris", []):
try:
schemes = handler_info["schemes"]
handler_path = handler_info["handler"]
factory_path = handler_info["factory"]
except KeyError:
continue
if handler_path is None:
handler = normalize_noop
elif not (handler := _correctness_check(provider_package, handler_path, provider)):
continue
if factory_path is None:
factory = create_dataset
elif not (factory := _correctness_check(provider_package, factory_path, provider)):
continue
self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
self._dataset_factories.update((scheme, factory) for scheme in schemes)

def _discover_hook_lineage_readers(self) -> None:
for provider_package, provider in self._provider_dict.items():
for hook_lineage_reader in provider.data.get("hook-lineage-readers", []):
if _correctness_check(provider_package, hook_lineage_reader, provider):
self._hook_lineage_readers.add(hook_lineage_reader)

def _discover_taskflow_decorators(self) -> None:
for name, info in self._provider_dict.items():
@@ -1289,11 +1312,21 @@ def filesystem_module_names(self) -> list[str]:
self.initialize_providers_filesystems()
return sorted(self._fs_set)

@property
def dataset_factories(self) -> dict[str, Callable[..., Dataset]]:
self.initialize_providers_dataset_uri_handlers()
return self._dataset_factories

@property
def dataset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]:
self.initialize_providers_dataset_uri_handlers()
return self._dataset_uri_handlers

@property
def hook_lineage_readers(self) -> list[str]:
self.initialize_providers_hook_lineage_readers()
return sorted(self._hook_lineage_readers)

@property
def provider_configs(self) -> list[tuple[str, dict[str, Any]]]:
self.initialize_providers_configuration()
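A dataset factory registered under "factory" is just an importable callable
that returns an AIP-60 compliant Dataset; ProvidersManager then exposes it as
ProvidersManager().dataset_factories["myscheme"]. A hypothetical sketch:

from airflow.datasets import Dataset

def create_dataset(arg1: str, arg2: str = "default") -> Dataset:
    """Hypothetical factory building a canonical myscheme:// URI."""
    return Dataset(uri=f"myscheme://{arg1}/{arg2}")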
108 changes: 108 additions & 0 deletions tests/lineage/test_hook.py
@@ -0,0 +1,108 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest

from airflow.datasets import Dataset
from airflow.hooks.base import BaseHook
from airflow.lineage import hook
from airflow.lineage.hook import HookLineage, HookLineageCollector, NoOpCollector, get_hook_lineage_collector


class TestHookLineageCollector:
def test_are_datasets_collected(self):
lineage_collector = HookLineageCollector()
assert lineage_collector is not None
assert lineage_collector.collected_datasets == HookLineage()
input_hook = BaseHook()
output_hook = BaseHook()
lineage_collector.add_input_dataset({"uri": "s3://in_bucket/file"}, input_hook)
lineage_collector.add_output_dataset(
{"uri": "postgres://example.com:5432/database/default/table"}, output_hook
)
assert lineage_collector.collected_datasets == HookLineage(
[(Dataset("s3://in_bucket/file"), input_hook)],
[(Dataset("postgres://example.com:5432/database/default/table"), output_hook)],
)

@patch("airflow.lineage.hook.create_dataset")
def test_add_input_dataset(self, mock_create_dataset):
collector = HookLineageCollector()
mock_dataset = MagicMock(spec=Dataset)
mock_create_dataset.return_value = mock_dataset

dataset_kwargs = {"uri": "test_uri"}
hook = MagicMock()
collector.add_input_dataset(dataset_kwargs, hook)

assert collector.inputs == [(mock_dataset, hook)]
mock_create_dataset.assert_called_once_with("test_uri")

@patch("airflow.lineage.hook.ProvidersManager")
def test_create_dataset(self, mock_providers_manager):
def create_dataset(arg1, arg2="default"):
return Dataset(uri=f"myscheme://{arg1}/{arg2}")

mock_providers_manager.return_value.dataset_factories = {"myscheme": create_dataset}
collector = HookLineageCollector()
assert collector.create_dataset({"scheme": "myscheme", "arg1": "value_1"}) == Dataset(
"myscheme://value_1/default"
)
assert collector.create_dataset(
{"scheme": "myscheme", "arg1": "value_1", "arg2": "value_2"}
) == Dataset("myscheme://value_1/value_2")

def test_collected_datasets(self):
collector = HookLineageCollector()
inputs = [(MagicMock(spec=Dataset), MagicMock())]
outputs = [(MagicMock(spec=Dataset), MagicMock())]
collector.inputs = inputs
collector.outputs = outputs

hook_lineage = collector.collected_datasets
assert hook_lineage.inputs == inputs
assert hook_lineage.outputs == outputs

def test_has_collected(self):
collector = HookLineageCollector()
assert not collector.has_collected

collector.inputs = [(MagicMock(spec=Dataset), MagicMock())]
assert collector.has_collected


@pytest.mark.parametrize(
"has_readers, expected_class",
[
(True, HookLineageCollector),
(False, NoOpCollector),
],
)
@patch("airflow.lineage.hook.ProvidersManager")
def test_get_hook_lineage_collector(mock_providers_manager, has_readers, expected_class):
# reset global variable
hook._hook_lineage_collector = None
if has_readers:
mock_providers_manager.return_value.hook_lineage_readers = [MagicMock()]
else:
mock_providers_manager.return_value.hook_lineage_readers = []
assert isinstance(get_hook_lineage_collector(), expected_class)
assert get_hook_lineage_collector() is get_hook_lineage_collector()
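
To consume the collected lineage, a provider can ship a HookLineageReader
subclass and list its import path under hook-lineage-readers. A minimal
hypothetical sketch (the process method name is illustrative; the base class
only provides retrieve_hook_lineage):

from airflow.lineage.hook import HookLineage, HookLineageReader

class MyHookLineageReader(HookLineageReader):
    """Hypothetical reader that logs each collected dataset."""

    def process(self) -> None:
        lineage: HookLineage = self.retrieve_hook_lineage()
        for dataset, context in lineage.inputs:
            self.log.info("input: %s (from %s)", dataset.uri, type(context).__name__)
        for dataset, context in lineage.outputs:
            self.log.info("output: %s (from %s)", dataset.uri, type(context).__name__)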
