-
Notifications
You must be signed in to change notification settings - Fork 14.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add HookLineageCollector, which registers and holds lineage sent from hooks during task execution.
Add HookLineageReader, which determines whether HookLineageCollector should be enabled to process lineage sent from hooks. Add Dataset factories to make sure Datasets registered with HookLineageCollector are AIP-60 compliant. Signed-off-by: Jakub Dardzinski <[email protected]>
- Loading branch information
1 parent
d048122
commit 1293a07
Showing
5 changed files
with
302 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from __future__ import annotations | ||
|
||
from typing import Union | ||
|
||
import attr | ||
|
||
from airflow.datasets import Dataset, create_dataset | ||
from airflow.hooks.base import BaseHook | ||
from airflow.io.store import ObjectStore | ||
from airflow.providers_manager import ProvidersManager | ||
from airflow.utils.log.logging_mixin import LoggingMixin | ||
|
||
# The context object (a hook or an object store) that sent the lineage.
LineageContext = Union[BaseHook, ObjectStore]

# Module-level singleton; populated lazily by get_hook_lineage_collector().
_hook_lineage_collector: HookLineageCollector | None = None
|
||
|
||
@attr.define
class HookLineage:
    """Container for the input and output lineage gathered by HookLineageCollector."""

    # Each entry pairs a dataset with the context (hook / object store) that reported it.
    inputs: list[tuple[Dataset, LineageContext]] = attr.field(factory=list)
    outputs: list[tuple[Dataset, LineageContext]] = attr.field(factory=list)
|
||
|
||
class HookLineageCollector(LoggingMixin):
    """
    HookLineageCollector is a base class for collecting hook lineage information.

    It is used to collect the input and output datasets of a hook execution.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Each entry pairs a created Dataset with the context that reported it.
        self.inputs: list[tuple[Dataset, LineageContext]] = []
        self.outputs: list[tuple[Dataset, LineageContext]] = []

    @staticmethod
    def create_dataset(dataset_kwargs: dict) -> Dataset:
        """
        Create a Dataset instance from the given dataset kwargs.

        :param dataset_kwargs: either contains a ``uri`` key, or a ``scheme`` key together with
            the keyword arguments accepted by the dataset factory registered for that scheme.
            The dictionary is never modified.
        :return: the created Dataset.
        :raises ValueError: if neither ``uri`` nor ``scheme`` is provided, or if no dataset
            factory is registered for the given scheme.
        """
        if "uri" in dataset_kwargs:
            # Fallback to default factory using the provided URI
            return create_dataset(dataset_kwargs["uri"])

        # Work on a shallow copy: popping "scheme" from the caller's dict would make a
        # second call with the same kwargs fail with "Missing required parameter".
        factory_kwargs = dict(dataset_kwargs)
        scheme: str | None = factory_kwargs.pop("scheme", None)
        if not scheme:
            raise ValueError(
                "Missing required parameter: either 'uri' or 'scheme' must be provided to create a Dataset."
            )

        dataset_factory = ProvidersManager().dataset_factories.get(scheme)
        if not dataset_factory:
            raise ValueError(
                f"Unsupported scheme: '{scheme}'. Please provide a valid URI to create a Dataset."
            )

        return dataset_factory(**factory_kwargs)

    def add_input_dataset(self, dataset_kwargs: dict, hook: LineageContext):
        """Add the input dataset and its corresponding hook execution context to the collector."""
        dataset = self.create_dataset(dataset_kwargs)
        self.inputs.append((dataset, hook))

    def add_output_dataset(self, dataset_kwargs: dict, hook: LineageContext):
        """Add the output dataset and its corresponding hook execution context to the collector."""
        dataset = self.create_dataset(dataset_kwargs)
        self.outputs.append((dataset, hook))

    @property
    def collected_datasets(self) -> HookLineage:
        """Get the collected hook lineage information."""
        return HookLineage(self.inputs, self.outputs)

    @property
    def has_collected(self) -> bool:
        """Check if any datasets have been collected."""
        return len(self.inputs) != 0 or len(self.outputs) != 0
|
||
|
||
class NoOpCollector(HookLineageCollector):
    """
    NoOpCollector is a hook lineage collector that does nothing.

    It is used when you want to disable lineage collection.
    """

    # The overrides accept both positional and keyword arguments so that any call that is
    # valid on HookLineageCollector (e.g. ``add_input_dataset(dataset_kwargs=..., hook=...)``)
    # is silently ignored here instead of raising TypeError.
    def add_input_dataset(self, *_, **__):
        """Ignore the input dataset; nothing is collected."""
        pass

    def add_output_dataset(self, *_, **__):
        """Ignore the output dataset; nothing is collected."""
        pass

    @property
    def collected_datasets(
        self,
    ) -> HookLineage:
        """Log a warning and return an empty HookLineage."""
        self.log.warning("You should not call this as there's no reader.")
        return HookLineage([], [])
|
||
|
||
class HookLineageReader(LoggingMixin):
    """Reads back the hook lineage information gathered by HookLineageCollector."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but not used; the reader only binds itself
        # to the collector returned by get_hook_lineage_collector().
        collector = get_hook_lineage_collector()
        self.lineage_collector = collector

    def retrieve_hook_lineage(self) -> HookLineage:
        """Return whatever the bound collector reports as its collected datasets."""
        return self.lineage_collector.collected_datasets
|
||
|
||
def get_hook_lineage_collector() -> HookLineageCollector:
    """Return the module-wide lineage collector, creating it on first use."""
    global _hook_lineage_collector
    if _hook_lineage_collector:
        return _hook_lineage_collector

    # TODO: is there a better way to select the no-op collector?
    # A real collector is only created when ProvidersManager reports hook lineage
    # readers; otherwise the no-op variant is used.
    readers = ProvidersManager().hook_lineage_readers
    collector_cls = HookLineageCollector if readers else NoOpCollector
    _hook_lineage_collector = collector_cls()
    return _hook_lineage_collector
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from __future__ import annotations | ||
|
||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
|
||
from airflow.datasets import Dataset | ||
from airflow.hooks.base import BaseHook | ||
from airflow.lineage import hook | ||
from airflow.lineage.hook import HookLineage, HookLineageCollector, NoOpCollector, get_hook_lineage_collector | ||
|
||
|
||
class TestHookLineageCollector:
    """Unit tests for HookLineageCollector."""

    def test_are_datasets_collected(self):
        """Datasets added through the public API show up in collected_datasets."""
        lineage_collector = HookLineageCollector()
        assert lineage_collector is not None
        assert lineage_collector.collected_datasets == HookLineage()
        input_hook = BaseHook()
        output_hook = BaseHook()
        lineage_collector.add_input_dataset({"uri": "s3://in_bucket/file"}, input_hook)
        lineage_collector.add_output_dataset(
            {"uri": "postgres://example.com:5432/database/default/table"}, output_hook
        )
        assert lineage_collector.collected_datasets == HookLineage(
            [(Dataset("s3://in_bucket/file"), input_hook)],
            [(Dataset("postgres://example.com:5432/database/default/table"), output_hook)],
        )

    @patch("airflow.lineage.hook.create_dataset")
    def test_add_input_dataset(self, mock_create_dataset):
        """add_input_dataset builds the dataset from the URI and stores it with its hook."""
        collector = HookLineageCollector()
        mock_dataset = MagicMock(spec=Dataset)
        mock_create_dataset.return_value = mock_dataset

        dataset_kwargs = {"uri": "test_uri"}
        hook = MagicMock()
        collector.add_input_dataset(dataset_kwargs, hook)

        assert collector.inputs == [(mock_dataset, hook)]
        mock_create_dataset.assert_called_once_with("test_uri")

    @patch("airflow.lineage.hook.ProvidersManager")
    def test_create_dataset(self, mock_providers_manager):
        """create_dataset dispatches to the factory registered for the given scheme."""

        def create_dataset(arg1, arg2="default"):
            return Dataset(uri=f"myscheme://{arg1}/{arg2}")

        mock_providers_manager.return_value.dataset_factories = {"myscheme": create_dataset}
        collector = HookLineageCollector()
        assert collector.create_dataset({"scheme": "myscheme", "arg1": "value_1"}) == Dataset(
            "myscheme://value_1/default"
        )
        assert collector.create_dataset(
            {"scheme": "myscheme", "arg1": "value_1", "arg2": "value_2"}
        ) == Dataset("myscheme://value_1/value_2")

    def test_collected_datasets(self):
        """collected_datasets exposes the collector's inputs and outputs."""
        collector = HookLineageCollector()
        inputs = [(MagicMock(spec=Dataset), MagicMock())]
        outputs = [(MagicMock(spec=Dataset), MagicMock())]
        collector.inputs = inputs
        collector.outputs = outputs

        hook_lineage = collector.collected_datasets
        assert hook_lineage.inputs == inputs
        assert hook_lineage.outputs == outputs

    def test_has_collected(self):
        """has_collected is true as soon as either inputs or outputs are non-empty."""
        collector = HookLineageCollector()
        assert not collector.has_collected

        # Entries are (dataset, context) tuples, matching the collector's declared type.
        collector.inputs = [(MagicMock(spec=Dataset), MagicMock())]
        assert collector.has_collected

        # Outputs alone must be enough as well.
        collector.inputs = []
        assert not collector.has_collected
        collector.outputs = [(MagicMock(spec=Dataset), MagicMock())]
        assert collector.has_collected
|
||
|
||
@pytest.mark.parametrize(
    "has_readers, expected_class",
    [
        (True, HookLineageCollector),
        (False, NoOpCollector),
    ],
)
@patch("airflow.lineage.hook.ProvidersManager")
def test_get_hook_lineage_collector(mock_providers_manager, has_readers, expected_class):
    """The singleton is a real collector only when hook lineage readers are registered."""
    # Remember the current singleton so the collector built under the mocked
    # ProvidersManager does not leak into other tests.
    previous_collector = hook._hook_lineage_collector
    # reset global variable
    hook._hook_lineage_collector = None
    try:
        mock_providers_manager.return_value.hook_lineage_readers = [MagicMock()] if has_readers else []
        assert isinstance(get_hook_lineage_collector(), expected_class)
        assert get_hook_lineage_collector() is get_hook_lineage_collector()
    finally:
        hook._hook_lineage_collector = previous_collector