Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: ensure operator execute method is consistent across all execution base subclasses #805

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,13 @@ def build_cmd(

return dbt_cmd, env

@abstractmethod
def build_and_run_cmd(self, context: Context, cmd_flags: list[str]) -> Any:
"""Override this method for the operator to execute the dbt command"""

def execute(self, context: Context) -> Any | None: # type: ignore
self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())


class DbtBuildMixin:
"""Mixin for dbt build command."""
Expand Down
3 changes: 0 additions & 3 deletions cosmos/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ def build_command(self, context: Context, cmd_flags: list[str] | None = None) ->
self.environment: dict[str, Any] = {**env_vars, **self.environment}
self.command: list[str] = dbt_cmd

def execute(self, context: Context) -> None:
self.build_and_run_cmd(context=context)


class DbtBuildDockerOperator(DbtBuildMixin, DbtDockerBaseOperator):
"""
Expand Down
3 changes: 0 additions & 3 deletions cosmos/operators/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ def build_kube_args(self, context: Context, cmd_flags: list[str] | None = None)
self.build_env_args(env_vars)
self.arguments = dbt_cmd

def execute(self, context: Context) -> None:
self.build_and_run_cmd(context=context)


class DbtBuildKubernetesOperator(DbtBuildMixin, DbtKubernetesBaseOperator):
"""
Expand Down
3 changes: 0 additions & 3 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,6 @@ def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None
logger.info(result.output)
return result

def execute(self, context: Context) -> None:
self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())

def on_kill(self) -> None:
if self.cancel_query_on_kill:
self.subprocess_hook.log.info("Sending SIGINT signal to process group")
Expand Down
18 changes: 17 additions & 1 deletion tests/operators/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from unittest.mock import patch

from cosmos.operators.base import (
AbstractDbtBaseOperator,
Expand All @@ -14,11 +15,26 @@

def test_dbt_base_operator_is_abstract():
"""Tests that the abstract base operator cannot be instantiated since the base_cmd is not defined."""
expected_error = "Can't instantiate abstract class AbstractDbtBaseOperator with abstract methods? base_cmd"
expected_error = (
"Can't instantiate abstract class AbstractDbtBaseOperator with abstract methods base_cmd, build_and_run_cmd"
)
with pytest.raises(TypeError, match=expected_error):
AbstractDbtBaseOperator()


@pytest.mark.parametrize("cmd_flags", [["--some-flag"], []])
@patch("cosmos.operators.base.AbstractDbtBaseOperator.build_and_run_cmd")
def test_dbt_base_operator_execute(mock_build_and_run_cmd, cmd_flags, monkeypatch):
"""Tests that the base operator execute method calls the build_and_run_cmd method with the expected arguments."""
monkeypatch.setattr(AbstractDbtBaseOperator, "add_cmd_flags", lambda _: cmd_flags)
AbstractDbtBaseOperator.__abstractmethods__ = set()

base_operator = AbstractDbtBaseOperator(task_id="fake_task", project_dir="fake_dir")

base_operator.execute(context={})
mock_build_and_run_cmd.assert_called_once_with(context={}, cmd_flags=cmd_flags)


@pytest.mark.parametrize(
"dbt_command, dbt_operator_class",
[
Expand Down
Loading