Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add skip_on_exit_code to SSHOperator #36303

Merged
merged 1 commit into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions airflow/providers/ssh/operators/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
import warnings
from base64 import b64encode
from functools import cached_property
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Container, Sequence

from deprecated.classic import deprecated

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.models import BaseOperator
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.utils.types import NOTSET, ArgNotSet
Expand Down Expand Up @@ -60,6 +60,9 @@ class SSHOperator(BaseOperator):
The default is ``False`` but note that `get_pty` is forced to ``True``
when the `command` starts with ``sudo``.
:param banner_timeout: timeout to wait for banner from the server in seconds
:param skip_on_exit_code: If command exits with this exit code, leave the task
in ``skipped`` state (default: None). If set to ``None``, any non-zero
exit code will be treated as a failure.

If *do_xcom_push* is *True*, the numeric exit code emitted by
the ssh session is pushed to XCom under key ``ssh_exit``.
Expand Down Expand Up @@ -91,6 +94,7 @@ def __init__(
environment: dict | None = None,
get_pty: bool = False,
banner_timeout: float = 30.0,
skip_on_exit_code: int | Container[int] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -106,6 +110,13 @@ def __init__(
self.environment = environment
self.get_pty = get_pty
self.banner_timeout = banner_timeout
self.skip_on_exit_code = (
skip_on_exit_code
if isinstance(skip_on_exit_code, Container)
else [skip_on_exit_code]
if skip_on_exit_code
else []
)

@cached_property
def ssh_hook(self) -> SSHHook:
Expand Down Expand Up @@ -141,7 +152,7 @@ def get_ssh_client(self) -> SSHClient:
self.log.info("Creating ssh_client")
return self.hook.get_conn()

def exec_ssh_client_command(self, ssh_client: SSHClient, command: str):
def exec_ssh_client_command(self, ssh_client: SSHClient, command: str) -> tuple[int, bytes, bytes]:
warnings.warn(
"exec_ssh_client_command method on SSHOperator is deprecated, call "
"`ssh_hook.exec_ssh_client_command` instead",
Expand All @@ -156,6 +167,8 @@ def raise_for_status(self, exit_status: int, stderr: bytes, context=None) -> Non
if context and self.do_xcom_push:
ti = context.get("task_instance")
ti.xcom_push(key="ssh_exit", value=exit_status)
if exit_status in self.skip_on_exit_code:
raise AirflowSkipException(f"SSH command returned exit code {exit_status}. Skipping.")
if exit_status != 0:
raise AirflowException(f"SSH operator error: exit status = {exit_status}")

Expand Down
35 changes: 34 additions & 1 deletion tests/providers/ssh/operators/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import pytest
from paramiko.client import SSHClient

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models import TaskInstance
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.ssh.operators.ssh import SSHOperator
Expand Down Expand Up @@ -203,6 +203,39 @@ def test_ssh_client_managed_correctly(self):
self.hook.get_conn.assert_called_once()
self.hook.get_conn.return_value.__exit__.assert_called_once()

@pytest.mark.parametrize(
"extra_kwargs, actual_exit_code, expected_exc",
[
({}, 0, None),
({}, 100, AirflowException),
({"skip_on_exit_code": None}, 0, None),
({"skip_on_exit_code": None}, 100, AirflowException),
({"skip_on_exit_code": 100}, 100, AirflowSkipException),
({"skip_on_exit_code": 100}, 101, AirflowException),
({"skip_on_exit_code": [100]}, 100, AirflowSkipException),
({"skip_on_exit_code": [100]}, 101, AirflowException),
({"skip_on_exit_code": [100, 102]}, 101, AirflowException),
({"skip_on_exit_code": (100,)}, 100, AirflowSkipException),
({"skip_on_exit_code": (100,)}, 101, AirflowException),
],
)
def test_skip(self, extra_kwargs, actual_exit_code, expected_exc):
command = "not_a_real_command"
self.exec_ssh_client_command.return_value = (actual_exit_code, b"", b"")

operator = SSHOperator(
task_id="test",
ssh_hook=self.hook,
command=command,
**extra_kwargs,
)

if expected_exc is None:
operator.execute({})
else:
with pytest.raises(expected_exc):
operator.execute({})
Comment on lines +233 to +237
Copy link
Contributor

@josh-fell josh-fell Dec 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if expected_exc is None:
operator.execute({})
else:
with pytest.raises(expected_exc):
operator.execute({})
with pytest.raises(expected_exc):
operator.execute({})

It's a nit really, but you don't need the branch here.

Copy link
Contributor Author

@dolfinus dolfinus Dec 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are test cases with exit code 0 which does not raise exception. But passing None or empty tuple to pytest.raises is not supported:
ValueError: Expected an exception type or a tuple of exception types, but got None. Raising exceptions is already understood as failing the test, so you don't need any special code to say 'this should never raise an exception'.
ValueError: Expected an exception type or a tuple of exception types, but got (). Raising exceptions is already understood as failing the test, so you don't need any special code to say 'this should never raise an exception'.


def test_command_errored(self):
# Test that run_ssh_client_command works on invalid commands
command = "not_a_real_command"
Expand Down