Skip to content

Commit

Permalink
fix: execute method order for docker and kubernetes operators
Browse files Browse the repository at this point in the history
  • Loading branch information
jbandoro committed Feb 16, 2024
1 parent 8a1fa3f commit 2f570f0
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 18 deletions.
4 changes: 2 additions & 2 deletions cosmos/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)


class DbtDockerBaseOperator(DockerOperator, AbstractDbtBaseOperator): # type: ignore
class DbtDockerBaseOperator(AbstractDbtBaseOperator, DockerOperator): # type: ignore
"""
Executes a dbt core cli command in a Docker container.
Expand All @@ -50,7 +50,7 @@ def __init__(
def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any:
self.build_command(context, cmd_flags)
self.log.info(f"Running command: {self.command}")
result = super().execute(context)
result = DockerOperator.execute(self, context)
logger.info(result)

def build_command(self, context: Context, cmd_flags: list[str] | None = None) -> None:
Expand Down
4 changes: 2 additions & 2 deletions cosmos/operators/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
)


class DbtKubernetesBaseOperator(KubernetesPodOperator, AbstractDbtBaseOperator): # type: ignore
class DbtKubernetesBaseOperator(AbstractDbtBaseOperator, KubernetesPodOperator): # type: ignore
"""
Executes a dbt core cli command in a Kubernetes Pod.
Expand Down Expand Up @@ -73,7 +73,7 @@ def build_env_args(self, env: dict[str, str | bytes | PathLike[Any]]) -> None:
def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any:
self.build_kube_args(context, cmd_flags)
self.log.info(f"Running command: {self.arguments}")
result = super().execute(context)
result = KubernetesPodOperator.execute(self, context)
logger.info(result)

def build_kube_args(self, context: Context, cmd_flags: list[str] | None = None) -> None:
Expand Down
43 changes: 36 additions & 7 deletions tests/operators/test_docker.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,37 @@
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest

from airflow.utils.context import Context
from pendulum import datetime

from cosmos.operators.docker import (
DbtBuildDockerOperator,
DbtDockerBaseOperator,
DbtLSDockerOperator,
DbtRunDockerOperator,
DbtSeedDockerOperator,
DbtTestDockerOperator,
)


class ConcreteDbtDockerBaseOperator(DbtDockerBaseOperator):
base_cmd = ["cmd"]
@pytest.fixture()
def mock_docker_execute():
with patch("cosmos.operators.docker.DockerOperator.execute") as mock_execute:
yield mock_execute


def test_dbt_docker_operator_add_global_flags() -> None:
dbt_base_operator = ConcreteDbtDockerBaseOperator(
@pytest.fixture()
def base_operator(mock_docker_execute):
from cosmos.operators.docker import DbtDockerBaseOperator

class ConcreteDbtDockerBaseOperator(DbtDockerBaseOperator):
base_cmd = ["cmd"]

return ConcreteDbtDockerBaseOperator


def test_dbt_docker_operator_add_global_flags(base_operator) -> None:
dbt_base_operator = base_operator(
conn_id="my_airflow_connection",
task_id="my-task",
image="my_image",
Expand All @@ -38,12 +50,29 @@ def test_dbt_docker_operator_add_global_flags() -> None:
]


@patch("cosmos.operators.docker.DbtDockerBaseOperator.build_command")
def test_dbt_docker_operator_execute(mock_build_command, base_operator, mock_docker_execute):
"""Tests that the execute method call results in both the build_command method and the docker execute method being called."""
operator = base_operator(
conn_id="my_airflow_connection",
task_id="my-task",
image="my_image",
project_dir="my/dir",
)
operator.execute(context={})
# Assert that the build_command method was called in the execution
mock_build_command.assert_called_once()
# Assert that the docker execute method was called in the execution
mock_docker_execute.assert_called_once()
assert mock_docker_execute.call_args.args[-1] == {}


@patch("cosmos.operators.base.context_to_airflow_vars")
def test_dbt_docker_operator_get_env(p_context_to_airflow_vars: MagicMock) -> None:
def test_dbt_docker_operator_get_env(p_context_to_airflow_vars: MagicMock, base_operator) -> None:
"""
If an end user passes in a
"""
dbt_base_operator = ConcreteDbtDockerBaseOperator(
dbt_base_operator = base_operator(
conn_id="my_airflow_connection",
task_id="my-task",
image="my_image",
Expand Down
42 changes: 35 additions & 7 deletions tests/operators/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from cosmos.operators.kubernetes import (
DbtBuildKubernetesOperator,
DbtKubernetesBaseOperator,
DbtLSKubernetesOperator,
DbtRunKubernetesOperator,
DbtSeedKubernetesOperator,
Expand All @@ -24,12 +23,24 @@
module_available = False


class ConcreteDbtKubernetesBaseOperator(DbtKubernetesBaseOperator):
base_cmd = ["cmd"]
@pytest.fixture()
def mock_kubernetes_execute():
with patch("cosmos.operators.kubernetes.KubernetesPodOperator.execute") as mock_execute:
yield mock_execute


def test_dbt_kubernetes_operator_add_global_flags() -> None:
dbt_kube_operator = ConcreteDbtKubernetesBaseOperator(
@pytest.fixture()
def base_operator(mock_kubernetes_execute):
from cosmos.operators.kubernetes import DbtKubernetesBaseOperator

class ConcreteDbtKubernetesBaseOperator(DbtKubernetesBaseOperator):
base_cmd = ["cmd"]

return ConcreteDbtKubernetesBaseOperator


def test_dbt_kubernetes_operator_add_global_flags(base_operator) -> None:
dbt_kube_operator = base_operator(
conn_id="my_airflow_connection",
task_id="my-task",
image="my_image",
Expand All @@ -48,12 +59,29 @@ def test_dbt_kubernetes_operator_add_global_flags() -> None:
]


@patch("cosmos.operators.kubernetes.DbtKubernetesBaseOperator.build_kube_args")
def test_dbt_kubernetes_operator_execute(mock_build_kube_args, base_operator, mock_kubernetes_execute):
"""Tests that the execute method call results in both the build_kube_args method and the kubernetes execute method being called."""
operator = base_operator(
conn_id="my_airflow_connection",
task_id="my-task",
image="my_image",
project_dir="my/dir",
)
operator.execute(context={})
# Assert that the build_command method was called in the execution
mock_build_kube_args.assert_called_once()
# Assert that the docker execute method was called in the execution
mock_kubernetes_execute.assert_called_once()
assert mock_kubernetes_execute.call_args.args[-1] == {}


@patch("cosmos.operators.base.context_to_airflow_vars")
def test_dbt_kubernetes_operator_get_env(p_context_to_airflow_vars: MagicMock) -> None:
def test_dbt_kubernetes_operator_get_env(p_context_to_airflow_vars: MagicMock, base_operator) -> None:
"""
If an end user passes in a
"""
dbt_kube_operator = ConcreteDbtKubernetesBaseOperator(
dbt_kube_operator = base_operator(
conn_id="my_airflow_connection",
task_id="my-task",
image="my_image",
Expand Down

0 comments on commit 2f570f0

Please sign in to comment.