Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new profile mapping configuration for Teradata #1077

Merged
merged 5 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping
from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .spark.thrift import SparkThriftProfileMapping
from .teradata.user_pass import TeradataUserPasswordProfileMapping
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
from .trino.ldap import TrinoLDAPProfileMapping
Expand All @@ -39,6 +40,7 @@
SnowflakePrivateKeyPemProfileMapping,
SparkThriftProfileMapping,
ExasolUserPasswordProfileMapping,
TeradataUserPasswordProfileMapping,
TrinoLDAPProfileMapping,
TrinoCertificateProfileMapping,
TrinoJWTProfileMapping,
Expand Down Expand Up @@ -79,6 +81,7 @@ def get_automatic_profile_mapping(
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"SparkThriftProfileMapping",
"ExasolUserPasswordProfileMapping",
"TeradataUserPasswordProfileMapping",
"TrinoLDAPProfileMapping",
"TrinoCertificateProfileMapping",
"TrinoJWTProfileMapping",
Expand Down
Empty file.
62 changes: 62 additions & 0 deletions cosmos/profiles/teradata/user_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Maps Airflow Snowflake connections to dbt profiles if they use a user/password."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping

if TYPE_CHECKING:
tatiana marked this conversation as resolved.
Show resolved Hide resolved
pass

Check warning on line 10 in cosmos/profiles/teradata/user_pass.py

View check run for this annotation

Codecov / codecov/patch

cosmos/profiles/teradata/user_pass.py#L10

Added line #L10 was not covered by tests


class TeradataUserPasswordProfileMapping(BaseProfileMapping):
    """
    Maps Airflow Teradata connections using user + password authentication to dbt profiles.
    https://docs.getdbt.com/docs/core/connect-data-platform/teradata-setup
    https://airflow.apache.org/docs/apache-airflow-providers-teradata/stable/connections/teradata.html
    """

    airflow_connection_type: str = "teradata"
    dbt_profile_type: str = "teradata"

    # ``schema`` is deliberately not required: ``profile`` falls back to the
    # user name when no schema is supplied.
    required_fields = [
        "host",
        "user",
        "password",
    ]
    secret_fields = [
        "password",
    ]
    # dbt profile key -> Airflow connection attribute ("extra.<key>" reads from the connection extras).
    airflow_param_mapping = {
        "host": "host",
        "user": "login",
        "password": "password",
        "schema": "schema",
        "tmode": "extra.tmode",
    }

    @property
    def profile(self) -> dict[str, Any | None]:
        """
        Gets profile. The password is stored in an environment variable.

        Values from ``profile_args`` take precedence over values mapped from the
        Airflow connection. When neither source provides a usable schema, the
        user name is used as the schema.
        """
        profile = {
            **self.mapped_params,
            **self.profile_args,
            # password should always get set as env var
            "password": self.get_env_var_format("password"),
        }
        # A schema is optional on the Airflow side. In Teradata a user is itself a
        # database, so fall back to the user name whenever the merged profile has no
        # schema. Checking the merged value (rather than key membership in
        # profile_args) also covers an explicit ``{"schema": None}`` override, which
        # would otherwise leave the profile without a schema.
        if profile.get("schema") is None:
            profile["schema"] = profile["user"]

        return self.filter_null(profile)

    @property
    def mock_profile(self) -> dict[str, Any | None]:
        """Gets mock profile. No Teradata-specific overrides are applied to the parent's mock values."""
        parent_mock = super().mock_profile

        return {
            **parent_mock,
        }
tatiana marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ dbt-all = [
"dbt-redshift",
"dbt-snowflake",
"dbt-spark",
"dbt-teradata",
"dbt-vertica",
]
dbt-athena = ["dbt-athena-community", "apache-airflow-providers-amazon>=8.0.0"]
Expand All @@ -62,6 +63,7 @@ dbt-postgres = ["dbt-postgres"]
dbt-redshift = ["dbt-redshift"]
dbt-snowflake = ["dbt-snowflake"]
dbt-spark = ["dbt-spark"]
dbt-teradata = ["dbt-teradata"]
dbt-vertica = ["dbt-vertica<=1.5.4"]
openlineage = ["openlineage-integration-common!=1.15.0", "openlineage-airflow"]
all = ["astronomer-cosmos[dbt-all]", "astronomer-cosmos[openlineage]"]
Expand Down
Empty file.
176 changes: 176 additions & 0 deletions tests/profiles/teradata/test_teradata_user_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""Tests for the postgres profile."""

from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.teradata.user_pass import TeradataUserPasswordProfileMapping


@pytest.fixture()
def mock_teradata_conn():  # type: ignore
    """
    Yields a minimal Teradata connection (host, login, password) while
    ``BaseHook.get_connection`` is patched to return it.
    """
    teradata_connection = Connection(
        password="my_password",
        login="my_user",
        host="my_host",
        conn_type="teradata",
        conn_id="my_teradata_connection",
    )
    get_connection_patch = patch(
        "airflow.hooks.base.BaseHook.get_connection",
        return_value=teradata_connection,
    )

    with get_connection_patch:
        yield teradata_connection


@pytest.fixture()
def mock_teradata_conn_custom_tmode():  # type: ignore
    """
    Yields a Teradata connection that also carries a schema and a ``tmode`` extra,
    while ``BaseHook.get_connection`` is patched to return it.
    """
    teradata_connection = Connection(
        extra='{"tmode": "TERA"}',
        schema="my_database",
        password="my_password",
        login="my_user",
        host="my_host",
        conn_type="teradata",
        conn_id="my_teradata_connection",
    )
    get_connection_patch = patch(
        "airflow.hooks.base.BaseHook.get_connection",
        return_value=teradata_connection,
    )

    with get_connection_patch:
        yield teradata_connection


def test_connection_claiming() -> None:
    """
    Tests that the teradata profile mapping claims the correct connection type.
    """
    # A connection should only be claimed when conn_type == teradata and
    # host, login (the dbt "user") and password are all present.
    complete_kwargs: dict[str, str] = {
        "conn_type": "teradata",
        "host": "my_host",
        "login": "my_user",
        "password": "my_password",
    }

    def claims(conn_kwargs, *mapping_args):  # type: ignore
        """Build the connection and mapping, then report whether the mapping claims it."""
        connection = Connection(**conn_kwargs)  # type: ignore
        with patch("airflow.hooks.base.BaseHook.get_connection", return_value=connection):
            mapping = TeradataUserPasswordProfileMapping(connection, *mapping_args)
            return mapping.can_claim_connection()

    # Dropping any single required value must prevent the claim.
    for missing_key in complete_kwargs:
        partial_kwargs = {name: value for name, value in complete_kwargs.items() if name != missing_key}
        assert not claims(partial_kwargs)

    # A missing schema does not block the claim: in Teradata the user is itself
    # a database, so the user name can stand in for the schema.
    assert claims(complete_kwargs, {"schema": None})

    # With every value present (and an explicit schema) it should claim.
    assert claims(complete_kwargs, {"schema": "my_schema"})


def test_profile_mapping_selected(mock_teradata_conn: Connection) -> None:
    """
    Tests that automatic selection picks the Teradata user/password mapping
    for a Teradata connection.
    """
    connection_id = mock_teradata_conn.conn_id
    selected_mapping = get_automatic_profile_mapping(connection_id)

    assert isinstance(selected_mapping, TeradataUserPasswordProfileMapping)


def test_profile_mapping_keeps_port(mock_teradata_conn: Connection) -> None:
    """
    Tests that a port supplied through profile_args is kept in the profile.
    """
    # The port is not mapped from the Airflow connection, so it is passed as a profile arg.
    extra_args = {"port": 1025}
    mapping = TeradataUserPasswordProfileMapping(mock_teradata_conn.conn_id, profile_args=extra_args)
    rendered_profile = mapping.profile

    assert rendered_profile["port"] == 1025


def test_profile_mapping_keeps_custom_tmode(mock_teradata_conn_custom_tmode: Connection) -> None:
    """
    Tests that a ``tmode`` set in the connection extras is carried into the profile.
    """
    mapping = TeradataUserPasswordProfileMapping(mock_teradata_conn_custom_tmode.conn_id)
    rendered_profile = mapping.profile

    assert rendered_profile["tmode"] == "TERA"


def test_profile_args(mock_teradata_conn: Connection) -> None:
    """
    Tests that the profile args are stored and merged into the generated profile.
    """
    supplied_args = {"schema": "my_database"}
    mapping = get_automatic_profile_mapping(
        mock_teradata_conn.conn_id,
        profile_args=supplied_args,
    )

    assert mapping.profile_args == {"schema": "my_database"}

    expected_profile = {
        "type": mock_teradata_conn.conn_type,
        "host": mock_teradata_conn.host,
        "user": mock_teradata_conn.login,
        "password": "{{ env_var('COSMOS_CONN_TERADATA_PASSWORD') }}",
        "schema": "my_database",
    }
    assert mapping.profile == expected_profile


def test_profile_args_overrides(mock_teradata_conn: Connection) -> None:
    """
    Tests that values given in profile_args end up in the generated profile.
    """
    supplied_args = {"schema": "my_schema"}
    mapping = get_automatic_profile_mapping(
        mock_teradata_conn.conn_id,
        profile_args=supplied_args,
    )

    assert mapping.profile_args == {"schema": "my_schema"}

    expected_profile = {
        "type": mock_teradata_conn.conn_type,
        "host": mock_teradata_conn.host,
        "user": mock_teradata_conn.login,
        "password": "{{ env_var('COSMOS_CONN_TERADATA_PASSWORD') }}",
        "schema": "my_schema",
    }
    assert mapping.profile == expected_profile


def test_profile_env_vars(mock_teradata_conn: Connection) -> None:
    """
    Tests that the connection password is exposed through the expected environment variable.
    """
    mapping = get_automatic_profile_mapping(
        mock_teradata_conn.conn_id,
        profile_args={"schema": "my_schema"},
    )
    expected_env = {"COSMOS_CONN_TERADATA_PASSWORD": mock_teradata_conn.password}

    assert mapping.env_vars == expected_env


def test_mock_profile() -> None:
    """
    Tests that the mock profile host value gets set correctly.
    """
    mapping = TeradataUserPasswordProfileMapping("mock_conn_id")
    mocked_values = mapping.mock_profile

    assert mocked_values.get("host") == "mock_value"
Loading