Skip to content

Commit

Permalink
Added new profile mapping configuration for Teradata (#1077)
Browse files Browse the repository at this point in the history
Teradata has
[Provider](https://airflow.apache.org/docs/apache-airflow-providers-teradata/stable/index.html)
in airflow and [adapter](https://github.com/Teradata/dbt-teradata) in
dbt. The cosmos library doesn't have profile configuration with mapping
support. This PR address this issue.

Closes: #1053
  • Loading branch information
sc250072 authored Jul 3, 2024
1 parent 30a7e5f commit 1c9e1f5
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 0 deletions.
3 changes: 3 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping
from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .spark.thrift import SparkThriftProfileMapping
from .teradata.user_pass import TeradataUserPasswordProfileMapping
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
from .trino.ldap import TrinoLDAPProfileMapping
Expand All @@ -39,6 +40,7 @@
SnowflakePrivateKeyPemProfileMapping,
SparkThriftProfileMapping,
ExasolUserPasswordProfileMapping,
TeradataUserPasswordProfileMapping,
TrinoLDAPProfileMapping,
TrinoCertificateProfileMapping,
TrinoJWTProfileMapping,
Expand Down Expand Up @@ -79,6 +81,7 @@ def get_automatic_profile_mapping(
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"SparkThriftProfileMapping",
"ExasolUserPasswordProfileMapping",
"TeradataUserPasswordProfileMapping",
"TrinoLDAPProfileMapping",
"TrinoCertificateProfileMapping",
"TrinoJWTProfileMapping",
Expand Down
Empty file.
51 changes: 51 additions & 0 deletions cosmos/profiles/teradata/user_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Maps Airflow Snowflake connections to dbt profiles if they use a user/password."""

from __future__ import annotations

from typing import Any

from ..base import BaseProfileMapping


class TeradataUserPasswordProfileMapping(BaseProfileMapping):
"""
Maps Airflow Teradata connections using user + password authentication to dbt profiles.
https://docs.getdbt.com/docs/core/connect-data-platform/teradata-setup
https://airflow.apache.org/docs/apache-airflow-providers-teradata/stable/connections/teradata.html
"""

airflow_connection_type: str = "teradata"
dbt_profile_type: str = "teradata"
is_community = True

required_fields = [
"host",
"user",
"password",
]
secret_fields = [
"password",
]
airflow_param_mapping = {
"host": "host",
"user": "login",
"password": "password",
"schema": "schema",
"tmode": "extra.tmode",
}

@property
def profile(self) -> dict[str, Any]:
"""Gets profile. The password is stored in an environment variable."""
profile = {
**self.mapped_params,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}
# schema is not mandatory in teradata. In teradata user itself a database so if schema is not mentioned
# in both airflow connection and profile_args then treating user as schema.
if "schema" not in self.profile_args and self.mapped_params.get("schema") is None:
profile["schema"] = profile["user"]

return self.filter_null(profile)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ dbt-all = [
"dbt-redshift",
"dbt-snowflake",
"dbt-spark",
"dbt-teradata",
"dbt-vertica",
]
dbt-athena = ["dbt-athena-community", "apache-airflow-providers-amazon>=8.0.0"]
Expand All @@ -62,6 +63,7 @@ dbt-postgres = ["dbt-postgres"]
dbt-redshift = ["dbt-redshift"]
dbt-snowflake = ["dbt-snowflake"]
dbt-spark = ["dbt-spark"]
dbt-teradata = ["dbt-teradata"]
dbt-vertica = ["dbt-vertica<=1.5.4"]
openlineage = ["openlineage-integration-common!=1.15.0", "openlineage-airflow"]
all = ["astronomer-cosmos[dbt-all]", "astronomer-cosmos[openlineage]"]
Expand Down
Empty file.
176 changes: 176 additions & 0 deletions tests/profiles/teradata/test_teradata_user_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""Tests for the postgres profile."""

from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.teradata.user_pass import TeradataUserPasswordProfileMapping


@pytest.fixture()
def mock_teradata_conn(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_teradata_connection",
conn_type="teradata",
host="my_host",
login="my_user",
password="my_password",
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


@pytest.fixture()
def mock_teradata_conn_custom_tmode(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_teradata_connection",
conn_type="teradata",
host="my_host",
login="my_user",
password="my_password",
schema="my_database",
extra='{"tmode": "TERA"}',
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_connection_claiming() -> None:
"""
Tests that the teradata profile mapping claims the correct connection type.
"""
# should only claim when:
# - conn_type == teradata
# and the following exist:
# - host
# - user
# - password
potential_values: dict[str, str] = {
"conn_type": "teradata",
"host": "my_host",
"login": "my_user",
"password": "my_password",
}

# if we're missing any of the values, it shouldn't claim
for key in potential_values:
values = potential_values.copy()
del values[key]
conn = Connection(**values) # type: ignore

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = TeradataUserPasswordProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# Even there is no schema, making user as schema as user itself schema in teradata
conn = Connection(**{k: v for k, v in potential_values.items() if k != "schema"})
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = TeradataUserPasswordProfileMapping(conn, {"schema": None})
assert profile_mapping.can_claim_connection()
# if we have them all, it should claim
conn = Connection(**potential_values) # type: ignore
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = TeradataUserPasswordProfileMapping(conn, {"schema": "my_schema"})
assert profile_mapping.can_claim_connection()


def test_profile_mapping_selected(
mock_teradata_conn: Connection,
) -> None:
"""
Tests that the correct profile mapping is selected.
"""
profile_mapping = get_automatic_profile_mapping(
mock_teradata_conn.conn_id,
)
assert isinstance(profile_mapping, TeradataUserPasswordProfileMapping)


def test_profile_mapping_keeps_port(mock_teradata_conn: Connection) -> None:
# port is not handled in airflow connection so adding it as profile_args
profile = TeradataUserPasswordProfileMapping(mock_teradata_conn.conn_id, profile_args={"port": 1025})
assert profile.profile["port"] == 1025


def test_profile_mapping_keeps_custom_tmode(mock_teradata_conn_custom_tmode: Connection) -> None:
profile = TeradataUserPasswordProfileMapping(mock_teradata_conn_custom_tmode.conn_id)
assert profile.profile["tmode"] == "TERA"


def test_profile_args(
mock_teradata_conn: Connection,
) -> None:
"""
Tests that the profile values get set correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_teradata_conn.conn_id,
profile_args={"schema": "my_database"},
)
assert profile_mapping.profile_args == {
"schema": "my_database",
}

assert profile_mapping.profile == {
"type": mock_teradata_conn.conn_type,
"host": mock_teradata_conn.host,
"user": mock_teradata_conn.login,
"password": "{{ env_var('COSMOS_CONN_TERADATA_PASSWORD') }}",
"schema": "my_database",
}


def test_profile_args_overrides(
mock_teradata_conn: Connection,
) -> None:
"""
Tests that you can override the profile values.
"""
profile_mapping = get_automatic_profile_mapping(
mock_teradata_conn.conn_id,
profile_args={"schema": "my_schema"},
)
assert profile_mapping.profile_args == {
"schema": "my_schema",
}

assert profile_mapping.profile == {
"type": mock_teradata_conn.conn_type,
"host": mock_teradata_conn.host,
"user": mock_teradata_conn.login,
"password": "{{ env_var('COSMOS_CONN_TERADATA_PASSWORD') }}",
"schema": "my_schema",
}


def test_profile_env_vars(
mock_teradata_conn: Connection,
) -> None:
"""
Tests that the environment variables get set correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_teradata_conn.conn_id,
profile_args={"schema": "my_schema"},
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_TERADATA_PASSWORD": mock_teradata_conn.password,
}


def test_mock_profile() -> None:
"""
Tests that the mock profile port value get set correctly.
"""
profile = TeradataUserPasswordProfileMapping("mock_conn_id")
assert profile.mock_profile.get("host") == "mock_value"

0 comments on commit 1c9e1f5

Please sign in to comment.