Skip to content

Commit

Permalink
bring back accidentally removed test 'test_execute_openlineage_events…
Browse files Browse the repository at this point in the history
…()' (#44832)

Co-authored-by: pratiksha rajendrabhai badheka <pratiksha@DESKTOP-T5HUA05>
  • Loading branch information
Prab-27 and pratiksha rajendrabhai badheka authored Dec 11, 2024
1 parent 72b7877 commit 37daea7
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions providers/tests/trino/hooks/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.common.compat.openlineage.facet import (
Dataset,
SchemaDatasetFacet,
SchemaDatasetFacetFields,
)
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.trino.hooks.trino import TrinoHook

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
Expand All @@ -37,6 +43,8 @@
TRINO_DBAPI_CONNECT = "airflow.providers.trino.hooks.trino.trino.dbapi.connect"
JWT_AUTHENTICATION = "airflow.providers.trino.hooks.trino.trino.auth.JWTAuthentication"
CERT_AUTHENTICATION = "airflow.providers.trino.hooks.trino.trino.auth.CertificateAuthentication"
TRINO_CONN_ID = "test_trino"
TRINO_DEFAULT = "trino_default"


@pytest.fixture
Expand Down Expand Up @@ -383,3 +391,56 @@ def test_connection_failure(self, mock_conn):
def test_serialize_cell(self):
assert self.db_hook._serialize_cell("foo", None) == "foo"
assert self.db_hook._serialize_cell(1, None) == 1


def test_execute_openlineage_events():
DB_NAME = "tpch"
DB_SCHEMA_NAME = "sf1"

class TrinoHookForTests(TrinoHook):
get_conn = mock.MagicMock(name="conn")
get_connection = mock.MagicMock()

def get_first(self, *_):
return [f"{DB_NAME}.{DB_SCHEMA_NAME}"]

dbapi_hook = TrinoHookForTests()

sql = "SELECT name FROM tpch.sf1.customer LIMIT 3"
op = SQLExecuteQueryOperator(task_id="trino_test", sql=sql, conn_id=TRINO_DEFAULT)
op._hook = dbapi_hook
rows = [
(DB_SCHEMA_NAME, "customer", "custkey", 1, "bigint", DB_NAME),
(DB_SCHEMA_NAME, "customer", "name", 2, "varchar(25)", DB_NAME),
(DB_SCHEMA_NAME, "customer", "address", 3, "varchar(40)", DB_NAME),
(DB_SCHEMA_NAME, "customer", "nationkey", 4, "bigint", DB_NAME),
(DB_SCHEMA_NAME, "customer", "phone", 5, "varchar(15)", DB_NAME),
(DB_SCHEMA_NAME, "customer", "acctbal", 6, "double", DB_NAME),
]
dbapi_hook.get_connection.return_value = Connection(
conn_id=TRINO_DEFAULT,
conn_type="trino",
host="trino",
port=8080,
)
dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]

lineage = op.get_openlineage_facets_on_start()
assert lineage.inputs == [
Dataset(
namespace="trino://trino:8080",
name=f"{DB_NAME}.{DB_SCHEMA_NAME}.customer",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaDatasetFacetFields(name="custkey", type="bigint"),
SchemaDatasetFacetFields(name="name", type="varchar(25)"),
SchemaDatasetFacetFields(name="address", type="varchar(40)"),
SchemaDatasetFacetFields(name="nationkey", type="bigint"),
SchemaDatasetFacetFields(name="phone", type="varchar(15)"),
SchemaDatasetFacetFields(name="acctbal", type="double"),
]
)
},
)
]

0 comments on commit 37daea7

Please sign in to comment.