Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openlineage, postgres: add OpenLineage support for Postgres #31617

Merged
merged 3 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import warnings
from contextlib import closing
from copy import deepcopy
from typing import Any, Iterable, Union
from typing import TYPE_CHECKING, Any, Iterable, Union

import psycopg2
import psycopg2.extensions
Expand All @@ -33,6 +33,9 @@
from airflow.models.connection import Connection
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
from airflow.providers.openlineage.sqlparser import DatabaseInfo

CursorType = Union[DictCursor, RealDictCursor, NamedTupleCursor]


Expand Down Expand Up @@ -317,3 +320,46 @@ def _generate_insert_sql(
sql += f"{on_conflict_str} DO NOTHING"

return sql

def get_openlineage_database_info(self, connection) -> DatabaseInfo:
    """Return database metadata (scheme, authority, database) for OpenLineage.

    Handles both plain Postgres connections and Redshift connections,
    detected via the ``redshift`` flag in the connection extras.
    """
    # Imported lazily so the openlineage provider remains an optional dependency.
    from airflow.providers.openlineage.sqlparser import DatabaseInfo

    is_redshift = bool(connection.extra_dejson.get("redshift", False))

    if is_redshift:
        scheme = "redshift"
        authority = self._get_openlineage_redshift_authority_part(connection)
    else:
        scheme = "postgres"
        authority = DbApiHook.get_openlineage_authority_part(connection, default_port=5432)

    return DatabaseInfo(
        scheme=scheme,
        authority=authority,
        database=self.database or connection.schema,
    )

def _get_openlineage_redshift_authority_part(self, connection) -> str:
    """Build the OpenLineage authority (``cluster.region:port``) for a Redshift connection.

    :param connection: Airflow connection whose extras may carry
        ``aws_conn_id`` and ``cluster-identifier``.
    :raises AirflowException: if the amazon provider is not installed.
    """
    try:
        from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
    except ImportError:
        from airflow.exceptions import AirflowException

        raise AirflowException(
            "apache-airflow-providers-amazon not installed, run: "
            "pip install 'apache-airflow-providers-postgres[amazon]'."
        )
    aws_conn_id = connection.extra_dejson.get("aws_conn_id", "aws_default")

    port = connection.port or 5439
    # Evaluate the host-based fallback lazily: ``connection.host`` may be None
    # when the cluster identifier is provided explicitly in the extras, and the
    # eager ``dict.get(key, host.split(...))`` form would then raise AttributeError.
    cluster_identifier = connection.extra_dejson.get("cluster-identifier")
    if not cluster_identifier:
        cluster_identifier = connection.host.split(".")[0]
    region_name = AwsBaseHook(aws_conn_id=aws_conn_id).region_name

    return f"{cluster_identifier}.{region_name}:{port}"

def get_openlineage_database_dialect(self, connection) -> str:
    """Return the SQL dialect name: ``redshift`` or ``postgres``."""
    if connection.extra_dejson.get("redshift", False):
        return "redshift"
    return "postgres"

def get_openlineage_default_schema(self) -> str | None:
"""Returns current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
return self.get_first("SELECT CURRENT_SCHEMA;")[0]
8 changes: 4 additions & 4 deletions dev/breeze/tests/test_selective_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"tests/providers/postgres/file.py",
),
{
"affected-providers-list-as-string": "amazon common.sql google postgres",
"affected-providers-list-as-string": "amazon common.sql google openlineage postgres",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"python-versions": "['3.8']",
Expand All @@ -110,7 +110,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"docs-build": "true",
"upgrade-to-newer-dependencies": "false",
"parallel-test-types-list-as-string": "Providers[amazon] "
"API Always Providers[common.sql,postgres] Providers[google]",
"API Always Providers[common.sql,openlineage,postgres] Providers[google]",
},
id="API and providers tests and docs should run",
)
Expand Down Expand Up @@ -164,7 +164,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"tests/providers/postgres/file.py",
),
{
"affected-providers-list-as-string": "amazon common.sql google postgres",
"affected-providers-list-as-string": "amazon common.sql google openlineage postgres",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"python-versions": "['3.8']",
Expand All @@ -177,7 +177,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"run-kubernetes-tests": "true",
"upgrade-to-newer-dependencies": "false",
"parallel-test-types-list-as-string": "Providers[amazon] "
"Always Providers[common.sql,postgres] Providers[google]",
"Always Providers[common.sql,openlineage,postgres] Providers[google]",
},
id="Helm tests, providers (both upstream and downstream),"
"kubernetes tests and docs should run",
Expand Down
3 changes: 2 additions & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,8 @@
],
"cross-providers-deps": [
"amazon",
"common.sql"
"common.sql",
"openlineage"
],
"excluded-python-versions": []
},
Expand Down
49 changes: 49 additions & 0 deletions tests/providers/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,55 @@ def test_schema_kwarg_database_kwarg_compatibility(self):
hook = PostgresHook(schema=database)
assert hook.database == database

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook")
@pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
@pytest.mark.parametrize("port", [5432, 5439, None])
@pytest.mark.parametrize(
    "host,conn_cluster_identifier,expected_host",
    [
        (
            "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
            NOTSET,
            "cluster-identifier.us-east-1",
        ),
        (
            "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
            "different-identifier",
            "different-identifier.us-east-1",
        ),
    ],
)
def test_openlineage_methods_with_redshift(
    self,
    mock_aws_hook_class,
    aws_conn_id,
    port,
    host,
    conn_cluster_identifier,
    expected_host,
):
    """Redshift authority is ``cluster.region:port``, derived from host or extras."""
    extra = {"iam": True, "redshift": True}
    if aws_conn_id is not NOTSET:
        extra["aws_conn_id"] = aws_conn_id
    if conn_cluster_identifier is not NOTSET:
        extra["cluster-identifier"] = conn_cluster_identifier

    self.connection.extra = json.dumps(extra)
    self.connection.host = host
    self.connection.port = port

    # The region comes from the (mocked) AWS hook, not from the hostname.
    mock_aws_hook_class.return_value.region_name = "us-east-1"

    authority = self.db_hook._get_openlineage_redshift_authority_part(self.connection)
    expected_port = port or 5439
    assert authority == f"{expected_host}:{expected_port}"


@pytest.mark.backend("postgres")
class TestPostgresHook:
Expand Down
79 changes: 78 additions & 1 deletion tests/providers/postgres/operators/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest

from airflow.models.dag import DAG
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.providers.postgres.operators.postgres import PostgresOperator
from airflow.utils import timezone

Expand All @@ -38,7 +39,6 @@ def setup_method(self):

def teardown_method(self):
tables_to_drop = ["test_postgres_to_postgres", "test_airflow"]
from airflow.providers.postgres.hooks.postgres import PostgresHook

with PostgresHook().get_conn() as conn:
with conn.cursor() as cur:
Expand Down Expand Up @@ -113,3 +113,80 @@ def test_runtime_parameter_setting(self):
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
assert op.get_db_hook().get_first("SHOW statement_timeout;")[0] == "3s"


@pytest.mark.backend("postgres")
class TestPostgresOpenLineage:
    """OpenLineage facet extraction for PostgresOperator against a live backend."""

    # Extra schemas created per-test so search_path behaviour can be exercised.
    custom_schemas = ["another_schema"]

    def setup_method(self):
        default_args = {"owner": "airflow", "start_date": DEFAULT_DATE}
        self.dag = DAG(TEST_DAG_ID, default_args=default_args)

        with PostgresHook().get_conn() as conn:
            with conn.cursor() as cur:
                for schema in self.custom_schemas:
                    cur.execute(f"CREATE SCHEMA {schema}")

    def teardown_method(self):
        tables_to_drop = ["test_postgres_to_postgres", "test_airflow"]

        with PostgresHook().get_conn() as conn:
            with conn.cursor() as cur:
                for table in tables_to_drop:
                    cur.execute(f"DROP TABLE IF EXISTS {table}")
                for schema in self.custom_schemas:
                    cur.execute(f"DROP SCHEMA {schema} CASCADE")

    def _assert_creation_lineage(self, op, expected_name):
        """Run ``op`` and verify OpenLineage reports exactly one output dataset."""
        lineage = op.get_openlineage_facets_on_start()
        assert len(lineage.inputs) == 0
        assert len(lineage.outputs) == 0
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        # OpenLineage provider runs same method on complete by default
        lineage_on_complete = op.get_openlineage_facets_on_start()
        assert len(lineage_on_complete.inputs) == 0
        assert len(lineage_on_complete.outputs) == 1
        output = lineage_on_complete.outputs[0]
        assert output.namespace == "postgres://postgres:5432"
        assert output.name == expected_name
        assert "schema" in output.facets

    def test_postgres_operator_openlineage_implicit_schema(self):
        # No schema in the DDL: the table lands in the first search_path entry.
        sql = """
        CREATE TABLE IF NOT EXISTS test_airflow (
            dummy VARCHAR(50)
        );
        """
        op = PostgresOperator(
            task_id="basic_postgres",
            sql=sql,
            dag=self.dag,
            hook_params={"options": "-c search_path=another_schema"},
        )

        self._assert_creation_lineage(op, "airflow.another_schema.test_airflow")

    def test_postgres_operator_openlineage_explicit_schema(self):
        # Schema-qualified DDL wins over the session search_path.
        sql = """
        CREATE TABLE IF NOT EXISTS public.test_airflow (
            dummy VARCHAR(50)
        );
        """
        op = PostgresOperator(
            task_id="basic_postgres",
            sql=sql,
            dag=self.dag,
            hook_params={"options": "-c search_path=another_schema"},
        )

        self._assert_creation_lineage(op, "airflow.public.test_airflow")