Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor common executor constructors with test coverage #774

Merged
Merged
119 changes: 113 additions & 6 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
from typing import Any, Sequence, Tuple
from abc import ABCMeta, abstractmethod

import yaml
from airflow.models.baseoperator import BaseOperator
Expand All @@ -15,14 +16,13 @@
logger = get_logger(__name__)


class DbtBaseOperator(BaseOperator):
class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta):
"""
Executes a dbt core cli command.

:param project_dir: Which directory to look in for the dbt_project.yml file. Default is the current working
directory and its parents.
:param conn_id: The airflow connection to use as the target
:param base_cmd: dbt sub-command to run (i.e ls, seed, run, test, etc.)
:param select: dbt optional argument that specifies which nodes to include.
:param exclude: dbt optional argument that specifies which models to exclude.
:param selector: dbt optional argument - the selector name to use, as defined in selectors.yml
Expand Down Expand Up @@ -78,11 +78,15 @@ class DbtBaseOperator(BaseOperator):

intercept_flag = True

@property
@abstractmethod
def base_cmd(self) -> list[str]:
"""Override this property to set the dbt sub-command (i.e ls, seed, run, test, etc.) for the operator"""

def __init__(
self,
project_dir: str,
conn_id: str | None = None,
base_cmd: list[str] | None = None,
select: str | None = None,
exclude: str | None = None,
selector: str | None = None,
Expand All @@ -109,7 +113,6 @@ def __init__(
) -> None:
self.project_dir = project_dir
self.conn_id = conn_id
self.base_cmd = base_cmd
self.select = select
self.exclude = exclude
self.selector = selector
Expand Down Expand Up @@ -203,6 +206,10 @@ def add_global_flags(self) -> list[str]:
flags.append(f"--{global_boolean_flag.replace('_', '-')}")
return flags

def add_cmd_flags(self) -> list[str]:
"""Allows subclasses to override to add flags for their dbt command"""
return []

def build_cmd(
self,
context: Context,
Expand All @@ -212,8 +219,7 @@ def build_cmd(

dbt_cmd.extend(self.dbt_cmd_global_flags)

if self.base_cmd:
dbt_cmd.extend(self.base_cmd)
dbt_cmd.extend(self.base_cmd)

if self.indirect_selection:
dbt_cmd += ["--indirect-selection", self.indirect_selection]
Expand All @@ -231,3 +237,104 @@ def build_cmd(
env = self.get_env(context)

return dbt_cmd, env


class DbtLSMixin:
"""
Executes a dbt core ls command.
"""

base_cmd = ["ls"]
ui_color = "#DBCDF6"


class DbtSeedMixin:
"""
Mixin for dbt seed operation command.

:param full_refresh: whether to add the flag --full-refresh to the dbt seed command
"""

base_cmd = ["seed"]
ui_color = "#F58D7E"

template_fields: Sequence[str] = ("full_refresh",)

def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)

def add_cmd_flags(self) -> list[str]:
flags = []
if self.full_refresh is True:
flags.append("--full-refresh")

return flags


class DbtSnapshotMixin:
"""Mixin for a dbt snapshot command."""

base_cmd = ["snapshot"]
ui_color = "#964B00"


class DbtRunMixin:
"""
Mixin for dbt run command.

:param full_refresh: whether to add the flag --full-refresh to the dbt seed command
"""

base_cmd = ["run"]
ui_color = "#7352BA"
ui_fgcolor = "#F4F2FC"

template_fields: Sequence[str] = ("full_refresh",)

def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)

def add_cmd_flags(self) -> list[str]:
flags = []
if self.full_refresh is True:
flags.append("--full-refresh")

return flags


class DbtTestMixin:
"""Mixin for dbt test command."""

base_cmd = ["test"]
ui_color = "#8194E0"


class DbtRunOperationMixin:
"""
Mixin for dbt run operation command.

:param macro_name: name of macro to execute
:param args: Supply arguments to the macro. This dictionary will be mapped to the keyword arguments defined in the
selected macro.
"""

ui_color = "#8194E0"
template_fields: Sequence[str] = ("args",)

def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: Any) -> None:
self.macro_name = macro_name
self.args = args
super().__init__(**kwargs)

@property
def base_cmd(self) -> list[str]:
return ["run-operation", self.macro_name]

def add_cmd_flags(self) -> list[str]:
flags = []
if self.args is not None:
flags.append("--args")
flags.append(yaml.dump(self.args))
return flags
90 changes: 22 additions & 68 deletions cosmos/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@

from typing import Any, Callable, Sequence

import yaml
from airflow.utils.context import Context

from cosmos.log import get_logger
from cosmos.operators.base import DbtBaseOperator
from cosmos.operators.base import (
AbstractDbtBaseOperator,
DbtRunMixin,
DbtSeedMixin,
DbtSnapshotMixin,
DbtTestMixin,
DbtLSMixin,
DbtRunOperationMixin,
)

logger = get_logger(__name__)

Expand All @@ -20,13 +27,15 @@
)


class DbtDockerBaseOperator(DockerOperator, DbtBaseOperator): # type: ignore
class DbtDockerBaseOperator(DockerOperator, AbstractDbtBaseOperator): # type: ignore
"""
Executes a dbt core cli command in a Docker container.

"""

template_fields: Sequence[str] = tuple(list(DbtBaseOperator.template_fields) + list(DockerOperator.template_fields))
template_fields: Sequence[str] = tuple(
list(AbstractDbtBaseOperator.template_fields) + list(DockerOperator.template_fields)
)

intercept_flag = False

Expand Down Expand Up @@ -57,85 +66,48 @@ def execute(self, context: Context) -> None:
self.build_and_run_cmd(context=context)


class DbtLSDockerOperator(DbtDockerBaseOperator):
class DbtLSDockerOperator(DbtLSMixin, DbtDockerBaseOperator):
"""
Executes a dbt core ls command.
"""

ui_color = "#DBCDF6"

def __init__(self, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = ["ls"]


class DbtSeedDockerOperator(DbtDockerBaseOperator):
class DbtSeedDockerOperator(DbtSeedMixin, DbtDockerBaseOperator):
"""
Executes a dbt core seed command.

:param full_refresh: dbt optional arg - dbt will treat incremental models as table models
"""

ui_color = "#F58D7E"

def __init__(self, full_refresh: bool = False, **kwargs: str) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)
self.base_cmd = ["seed"]
template_fields: Sequence[str] = DbtDockerBaseOperator.template_fields + DbtSeedMixin.template_fields # type: ignore[operator]

def add_cmd_flags(self) -> list[str]:
flags = []
if self.full_refresh is True:
flags.append("--full-refresh")

return flags

def execute(self, context: Context) -> None:
cmd_flags = self.add_cmd_flags()
self.build_and_run_cmd(context=context, cmd_flags=cmd_flags)


class DbtSnapshotDockerOperator(DbtDockerBaseOperator):
class DbtSnapshotDockerOperator(DbtSnapshotMixin, DbtDockerBaseOperator):
"""
Executes a dbt core snapshot command.

"""

ui_color = "#964B00"

def __init__(self, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = ["snapshot"]


class DbtRunDockerOperator(DbtDockerBaseOperator):
class DbtRunDockerOperator(DbtRunMixin, DbtDockerBaseOperator):
"""
Executes a dbt core run command.
"""

ui_color = "#7352BA"
ui_fgcolor = "#F4F2FC"

def __init__(self, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = ["run"]
template_fields: Sequence[str] = DbtDockerBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator]


class DbtTestDockerOperator(DbtDockerBaseOperator):
class DbtTestDockerOperator(DbtTestMixin, DbtDockerBaseOperator):
"""
Executes a dbt core test command.
"""

ui_color = "#8194E0"

def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = ["test"]
# as of now, on_warning_callback in docker executor does nothing
self.on_warning_callback = on_warning_callback


class DbtRunOperationDockerOperator(DbtDockerBaseOperator):
class DbtRunOperationDockerOperator(DbtRunOperationMixin, DbtDockerBaseOperator):
"""
Executes a dbt core run-operation command.

Expand All @@ -144,22 +116,4 @@ class DbtRunOperationDockerOperator(DbtDockerBaseOperator):
selected macro.
"""

ui_color = "#8194E0"
template_fields: Sequence[str] = ("args",)

def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: str) -> None:
self.macro_name = macro_name
self.args = args
super().__init__(**kwargs)
self.base_cmd = ["run-operation", macro_name]

def add_cmd_flags(self) -> list[str]:
flags = []
if self.args is not None:
flags.append("--args")
flags.append(yaml.dump(self.args))
return flags

def execute(self, context: Context) -> None:
cmd_flags = self.add_cmd_flags()
self.build_and_run_cmd(context=context, cmd_flags=cmd_flags)
template_fields: Sequence[str] = DbtDockerBaseOperator.template_fields + DbtRunOperationMixin.template_fields # type: ignore[operator]
Loading
Loading