From 37daea78e81ecab318ec0c4ad3c66a8b4e2f73ca Mon Sep 17 00:00:00 2001 From: Pratiksha <128999446+Prab-27@users.noreply.github.com> Date: Wed, 11 Dec 2024 12:29:01 +0530 Subject: [PATCH] bring back accidentally removed test 'test_execute_openlineage_events()' (#44832) Co-authored-by: pratiksha rajendrabhai badheka --- providers/tests/trino/hooks/test_trino.py | 61 +++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/providers/tests/trino/hooks/test_trino.py b/providers/tests/trino/hooks/test_trino.py index a9a4137da8de1..b0608023118e3 100644 --- a/providers/tests/trino/hooks/test_trino.py +++ b/providers/tests/trino/hooks/test_trino.py @@ -27,6 +27,12 @@ from airflow.exceptions import AirflowException from airflow.models import Connection +from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + SchemaDatasetFacet, + SchemaDatasetFacetFields, +) +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.trino.hooks.trino import TrinoHook from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS @@ -37,6 +43,8 @@ TRINO_DBAPI_CONNECT = "airflow.providers.trino.hooks.trino.trino.dbapi.connect" JWT_AUTHENTICATION = "airflow.providers.trino.hooks.trino.trino.auth.JWTAuthentication" CERT_AUTHENTICATION = "airflow.providers.trino.hooks.trino.trino.auth.CertificateAuthentication" +TRINO_CONN_ID = "test_trino" +TRINO_DEFAULT = "trino_default" @pytest.fixture @@ -383,3 +391,56 @@ def test_connection_failure(self, mock_conn): def test_serialize_cell(self): assert self.db_hook._serialize_cell("foo", None) == "foo" assert self.db_hook._serialize_cell(1, None) == 1 + + +def test_execute_openlineage_events(): + DB_NAME = "tpch" + DB_SCHEMA_NAME = "sf1" + + class TrinoHookForTests(TrinoHook): + get_conn = mock.MagicMock(name="conn") + get_connection = mock.MagicMock() + + def get_first(self, *_): + return [f"{DB_NAME}.{DB_SCHEMA_NAME}"] + + dbapi_hook = TrinoHookForTests() + + sql = "SELECT name FROM tpch.sf1.customer LIMIT 3" + op = SQLExecuteQueryOperator(task_id="trino_test", sql=sql, conn_id=TRINO_DEFAULT) + op._hook = dbapi_hook + rows = [ + (DB_SCHEMA_NAME, "customer", "custkey", 1, "bigint", DB_NAME), + (DB_SCHEMA_NAME, "customer", "name", 2, "varchar(25)", DB_NAME), + (DB_SCHEMA_NAME, "customer", "address", 3, "varchar(40)", DB_NAME), + (DB_SCHEMA_NAME, "customer", "nationkey", 4, "bigint", DB_NAME), + (DB_SCHEMA_NAME, "customer", "phone", 5, "varchar(15)", DB_NAME), + (DB_SCHEMA_NAME, "customer", "acctbal", 6, "double", DB_NAME), + ] + dbapi_hook.get_connection.return_value = Connection( + conn_id=TRINO_DEFAULT, + conn_type="trino", + host="trino", + port=8080, + ) + dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []] + + lineage = op.get_openlineage_facets_on_start() + assert lineage.inputs == [ + Dataset( + namespace="trino://trino:8080", + name=f"{DB_NAME}.{DB_SCHEMA_NAME}.customer", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="custkey", type="bigint"), + SchemaDatasetFacetFields(name="name", type="varchar(25)"), + SchemaDatasetFacetFields(name="address", type="varchar(40)"), + SchemaDatasetFacetFields(name="nationkey", type="bigint"), + SchemaDatasetFacetFields(name="phone", type="varchar(15)"), + SchemaDatasetFacetFields(name="acctbal", type="double"), + ] + ) + }, + ) + ]