Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dbt profile config variables to mapped profile #794

Merged
merged 24 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


from .athena import AthenaAccessKeyProfileMapping
from .base import BaseProfileMapping
from .base import BaseProfileMapping, DbtProfileConfigVars
from .bigquery.service_account_file import GoogleCloudServiceAccountFileProfileMapping
from .bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping
from .bigquery.oauth import GoogleCloudOauthProfileMapping
Expand Down Expand Up @@ -70,6 +70,7 @@ def get_automatic_profile_mapping(
"GoogleCloudServiceAccountDictProfileMapping",
"GoogleCloudOauthProfileMapping",
"DatabricksTokenProfileMapping",
"DbtProfileConfigVars",
"PostgresUserPasswordProfileMapping",
"RedshiftUserPasswordProfileMapping",
"SnowflakeUserPasswordProfileMapping",
Expand Down
65 changes: 60 additions & 5 deletions cosmos/profiles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

from typing import TYPE_CHECKING
import yaml
from typing import Any, Optional, Literal, Dict, TYPE_CHECKING
import warnings

from airflow.hooks.base import BaseHook
from pydantic import dataclasses
import yaml

from cosmos.exceptions import CosmosValueError
from cosmos.log import get_logger
Expand All @@ -24,6 +24,31 @@
logger = get_logger(__name__)


@dataclasses.dataclass
class DbtProfileConfigVars:
    """
    Platform-independent dbt settings that are written to the ``config`` block of the profile.

    Every field is optional. A field left as ``None`` is omitted from the output of
    ``as_dict``. ``send_anonymous_usage_stats`` is the only field whose default is not ``None``.
    """

    send_anonymous_usage_stats: Optional[bool] = False
    partial_parse: Optional[bool] = None
    use_experimental_parser: Optional[bool] = None
    static_parser: Optional[bool] = None
    # A width in characters, not a flag: a bool annotation rejects values such as 120.
    printer_width: Optional[int] = None
    write_json: Optional[bool] = None
    warn_error: Optional[bool] = None
    warn_error_options: Optional[Dict[Literal["include", "exclude"], Any]] = None
    log_format: Optional[Literal["text", "json", "default"]] = None
    debug: Optional[bool] = None
    version_check: Optional[bool] = None

    def as_dict(self) -> Optional[Dict[str, Any]]:
        """
        Return the fields that are set, keyed by field name.

        :returns: A dict holding every field whose value is not ``None``, or ``None`` when
            no field is set.
        """
        values = ((name, getattr(self, name)) for name in self.__dataclass_fields__)
        result = {name: value for name, value in values if value is not None}
        # An empty mapping is reported as None, so callers can tell "nothing set" apart.
        return result or None
ykuc marked this conversation as resolved.
Show resolved Hide resolved


class BaseProfileMapping(ABC):
"""
A base class that other profile mappings should inherit from to ensure consistency.
Expand All @@ -41,11 +66,19 @@ class BaseProfileMapping(ABC):

_conn: Connection | None = None

def __init__(self, conn_id: str, profile_args: dict[str, Any] | None = None, disable_event_tracking: bool = False):
def __init__(
    self,
    conn_id: str,
    profile_args: dict[str, Any] | None = None,
    disable_event_tracking: bool | None = None,
    dbt_config_vars: DbtProfileConfigVars | None = None,
):
    """
    Store the mapping settings on the instance and validate them.

    :param conn_id: Connection identifier kept on the instance.
    :param profile_args: Extra profile arguments; a missing or empty value becomes ``{}``.
    :param disable_event_tracking: Deprecated flag, checked by ``_validate_disable_event_tracking``.
    :param dbt_config_vars: Optional dbt config variables kept on the instance.
    """
    self.conn_id = conn_id
    self.profile_args = profile_args if profile_args else {}
    # Profile args are validated before the remaining attributes are assigned.
    self._validate_profile_args()

    self.disable_event_tracking = disable_event_tracking
    self.dbt_config_vars = dbt_config_vars
    self._validate_disable_event_tracking()

def _validate_profile_args(self) -> None:
"""
Expand All @@ -66,6 +99,25 @@ class variables when creating the profile.
)
)

def _validate_disable_event_tracking(self) -> None:
    """
    Warn that disable_event_tracking is deprecated and reject conflicting settings.

    Does nothing when ``disable_event_tracking`` is falsy.

    :raises CosmosValueError: If ``disable_event_tracking`` is set and ``dbt_config_vars``
        also carries a non-``None`` ``send_anonymous_usage_stats``.
    """
    if not self.disable_event_tracking:
        return

    deprecation_message = (
        "Disabling dbt event tracking is deprecated since Cosmos 1.3 and will be removed in Cosmos 2.0. "
        "Use dbt_config_vars=DbtProfileConfigVars(send_anonymous_usage_stats=False) instead."
    )
    warnings.warn(deprecation_message, DeprecationWarning)

    config_vars = self.dbt_config_vars
    # Both settings control the same option, so supplying both is ambiguous.
    if isinstance(config_vars, DbtProfileConfigVars) and config_vars.send_anonymous_usage_stats is not None:
        raise CosmosValueError(
            "Cannot set both disable_event_tracking and "
            "dbt_config_vars=DbtProfileConfigVars(send_anonymous_usage_stats ..."
        )

@property
def conn(self) -> Connection:
"Returns the Airflow connection."
Expand Down Expand Up @@ -180,6 +232,9 @@ def get_profile_file_contents(
}
}

if self.dbt_config_vars:
profile_contents["config"] = self.dbt_config_vars.as_dict()

if self.disable_event_tracking:
profile_contents["config"] = {"send_anonymous_usage_stats": False}

Expand Down
3 changes: 2 additions & 1 deletion dev/dags/cosmos_manifest_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pathlib import Path

from cosmos import DbtDag, ProjectConfig, ProfileConfig, RenderConfig, LoadMode, ExecutionConfig
from cosmos.profiles import PostgresUserPasswordProfileMapping
from cosmos.profiles import PostgresUserPasswordProfileMapping, DbtProfileConfigVars

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))
Expand All @@ -18,6 +18,7 @@
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="airflow_db",
profile_args={"schema": "public"},
dbt_config_vars=DbtProfileConfigVars(send_anonymous_usage_stats=True),
),
)

Expand Down
40 changes: 40 additions & 0 deletions docs/templates/index.rst.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ you specify in ``ProfileConfig``.

Disabling dbt event tracking
--------------------------------

.. note::
    Deprecated in v1.4 and will be removed in v2.0.0. Use ``dbt_config_vars=DbtProfileConfigVars(send_anonymous_usage_stats=False)`` instead.

.. versionadded:: 1.3

By default `dbt will track events <https://docs.getdbt.com/reference/global-configs/usage-stats>`_ by sending anonymous usage data
Expand Down Expand Up @@ -112,6 +115,43 @@ the example below:

dag = DbtDag(profile_config=profile_config, ...)

Dbt profile config variables
--------------------------------
.. versionadded:: 1.4.0

Use ``DbtProfileConfigVars`` to set the ``config`` section of ``profiles.yml``, which holds the settings that aren't specific to a particular data platform. See the `dbt docs <https://docs.getdbt.com/docs/core/connect-data-platform/profiles.yml>`_ for details.

.. code-block:: python

from cosmos.profiles import SnowflakeUserPasswordProfileMapping, DbtProfileConfigVars

profile_config = ProfileConfig(
profile_name="my_profile_name",
target_name="my_target_name",
profile_mapping=SnowflakeUserPasswordProfileMapping(
conn_id="my_snowflake_conn_id",
profile_args={
"database": "my_snowflake_database",
"schema": "my_snowflake_schema",
},
dbt_config_vars=DbtProfileConfigVars(
send_anonymous_usage_stats=False,
partial_parse=True,
use_experimental_parser=True,
static_parser=True,
printer_width=120,
write_json=True,
warn_error=True,
warn_error_options={"include": "all"},
log_format='text',
debug=True,
version_check=True,
),
),
)

dag = DbtDag(profile_config=profile_config, ...)




Expand Down
115 changes: 113 additions & 2 deletions tests/profiles/test_base_profile.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations
from typing import Any

import pytest
import yaml
from pydantic.error_wrappers import ValidationError

from cosmos.profiles.base import BaseProfileMapping
from cosmos.profiles.base import BaseProfileMapping, DbtProfileConfigVars
from cosmos.exceptions import CosmosValueError


Expand Down Expand Up @@ -37,7 +39,7 @@ def test_validate_profile_args(profile_arg: str):


@pytest.mark.parametrize("disable_event_tracking", [True, False])
def test_disable_event_tracking(disable_event_tracking: str):
def test_disable_event_tracking(disable_event_tracking: bool):
"""
Tests the config block in the profile is set correctly if disable_event_tracking is set.
"""
Expand All @@ -50,3 +52,112 @@ def test_disable_event_tracking(disable_event_tracking: str):
assert ("config" in profile_contents) == disable_event_tracking
ykuc marked this conversation as resolved.
Show resolved Hide resolved
if disable_event_tracking:
assert profile_contents["config"]["send_anonymous_usage_stats"] is False


def test_disable_event_tracking_and_send_anonymous_usage_stats():
    """
    Tests that setting both disable_event_tracking and send_anonymous_usage_stats raises a CosmosValueError.
    """
    config_vars = DbtProfileConfigVars(send_anonymous_usage_stats=False)

    with pytest.raises(CosmosValueError) as err_info:
        TestProfileMapping(
            conn_id="fake_conn_id",
            dbt_config_vars=config_vars,
            disable_event_tracking=True,
        )

    # Checked outside the ``with`` block: statements after the raising call would never run.
    raised_message = err_info.value.args[0]
    assert raised_message == (
        "Cannot set both disable_event_tracking and "
        "dbt_config_vars=DbtProfileConfigVars(send_anonymous_usage_stats ..."
    )


def test_dbt_profile_config_vars_none():
    """
    Tests that as_dict returns None when every config variable is explicitly set to None.
    """
    unset_fields = dict.fromkeys(
        [
            "send_anonymous_usage_stats",
            "partial_parse",
            "use_experimental_parser",
            "static_parser",
            "printer_width",
            "write_json",
            "warn_error",
            "warn_error_options",
            "log_format",
            "debug",
            "version_check",
        ]
    )
    dbt_config_vars = DbtProfileConfigVars(**unset_fields)

    assert dbt_config_vars.as_dict() is None


@pytest.mark.parametrize("config", [True, False])
def test_dbt_config_vars_config(config: bool):
    """
    Tests that the profile contains a config block only when dbt_config_vars is supplied.
    """
    dbt_config_vars = DbtProfileConfigVars(debug=False) if config else None
    test_profile = TestProfileMapping(conn_id="fake_conn_id", dbt_config_vars=dbt_config_vars)

    rendered_profile = test_profile.get_profile_file_contents(profile_name="fake-profile-name")
    profile_contents = yaml.safe_load(rendered_profile)

    has_config_block = "config" in profile_contents
    assert has_config_block == config


@pytest.mark.parametrize("dbt_config_var,dbt_config_value", [("debug", True), ("debug", False)])
def test_validate_dbt_config_vars(dbt_config_var: str, dbt_config_value: Any):
    """
    Tests that a config variable and its value are written to the config block of the profile.
    """
    config_vars = DbtProfileConfigVars(**{dbt_config_var: dbt_config_value})
    test_profile = TestProfileMapping(conn_id="fake_conn_id", dbt_config_vars=config_vars)

    rendered_profile = test_profile.get_profile_file_contents(profile_name="fake-profile-name")
    profile_contents = yaml.safe_load(rendered_profile)

    assert "config" in profile_contents
    config_block = profile_contents["config"]
    assert config_block[dbt_config_var] == dbt_config_value


@pytest.mark.parametrize(
    "dbt_config_var,dbt_config_value",
    [("send_anonymous_usage_stats", 2), ("send_anonymous_usage_stats", "aaa")],
)
def test_profile_config_validate_dbt_config_vars_check_unexpected_types(dbt_config_var: str, dbt_config_value: Any):
    """
    Tests that a config variable given a value of an unexpected type raises a ValidationError.
    """
    invalid_kwargs = {dbt_config_var: dbt_config_value}

    with pytest.raises(ValidationError):
        TestProfileMapping(
            dbt_config_vars=DbtProfileConfigVars(**invalid_kwargs),
            conn_id="fake_conn_id",
        )


@pytest.mark.parametrize("dbt_config_var,dbt_config_value", [("send_anonymous_usage_stats", True)])
def test_profile_config_validate_dbt_config_vars_check_expected_types(dbt_config_var: str, dbt_config_value: Any):
    """
    Tests that a config variable given a value of the expected type is accepted and returned by as_dict.
    """
    expected = {dbt_config_var: dbt_config_value}
    config_vars = DbtProfileConfigVars(**expected)

    profile_config = TestProfileMapping(dbt_config_vars=config_vars, conn_id="fake_conn_id")

    assert profile_config.dbt_config_vars.as_dict() == expected


@pytest.mark.parametrize(
    "dbt_config_var,dbt_config_value",
    [("log_format", "text2")],
)
def test_profile_config_validate_dbt_config_vars_check_values(dbt_config_var: str, dbt_config_value: Any):
    """
    Tests that a config variable given a value outside its allowed set raises a ValidationError.
    """
    invalid_kwargs = {dbt_config_var: dbt_config_value}

    with pytest.raises(ValidationError):
        TestProfileMapping(
            dbt_config_vars=DbtProfileConfigVars(**invalid_kwargs),
            conn_id="fake_conn_id",
        )
Loading