diff --git a/airflow/providers/docker/operators/docker_swarm.py b/airflow/providers/docker/operators/docker_swarm.py
index 00a158956c95f..2d5373c840f17 100644
--- a/airflow/providers/docker/operators/docker_swarm.py
+++ b/airflow/providers/docker/operators/docker_swarm.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Run ephemeral Docker Swarm services"""
-from typing import Optional
+from typing import List, Optional, Union
 
 import requests
 from docker import types
@@ -93,13 +93,40 @@ class DockerSwarmOperator(DockerOperator):
         Supported only if the Docker engine is using json-file or journald logging drivers.
         The `tty` parameter should be set to use this with Python applications.
     :type enable_logging: bool
+    :param configs: List of docker configs to be exposed to the containers of the swarm service.
+        The configs are ConfigReference objects as per the docker api
+        [https://docker-py.readthedocs.io/en/stable/services.html#docker.models.services.ServiceCollection.create]_
+    :type configs: List[docker.types.ConfigReference]
+    :param secrets: List of docker secrets to be exposed to the containers of the swarm service.
+        The secrets are SecretReference objects as per the docker create_service api.
+        [https://docker-py.readthedocs.io/en/stable/services.html#docker.models.services.ServiceCollection.create]_
+    :type secrets: List[docker.types.SecretReference]
+    :param mode: Indicate whether a service should be deployed as a replicated or global service,
+        and associated parameters
+    :type mode: docker.types.ServiceMode
+    :param networks: List of network names or IDs or NetworkAttachmentConfig to attach the service to.
+    :type networks: List[Union[str, NetworkAttachmentConfig]]
     """
 
-    def __init__(self, *, image: str, enable_logging: bool = True, **kwargs) -> None:
+    def __init__(
+        self,
+        *,
+        image: str,
+        enable_logging: bool = True,
+        configs: Optional[List[types.ConfigReference]] = None,
+        secrets: Optional[List[types.SecretReference]] = None,
+        mode: Optional[types.ServiceMode] = None,
+        networks: Optional[List[Union[str, types.NetworkAttachmentConfig]]] = None,
+        **kwargs,
+    ) -> None:
         super().__init__(image=image, **kwargs)
         self.enable_logging = enable_logging
         self.service = None
+        self.configs = configs
+        self.secrets = secrets
+        self.mode = mode
+        self.networks = networks
 
     def execute(self, context) -> None:
         self.cli = self._get_cli()
@@ -121,12 +148,16 @@ def _run_service(self) -> None:
                     env=self.environment,
                     user=self.user,
                     tty=self.tty,
+                    configs=self.configs,
+                    secrets=self.secrets,
                 ),
                 restart_policy=types.RestartPolicy(condition='none'),
                 resources=types.Resources(mem_limit=self.mem_limit),
+                networks=self.networks,
             ),
             name=f'airflow-{get_random_string()}',
             labels={'name': f'airflow__{self.dag_id}__{self.task_id}'},
+            mode=self.mode,
         )
         self.log.info('Service started: %s', str(self.service))
 
diff --git a/tests/providers/docker/operators/test_docker_swarm.py b/tests/providers/docker/operators/test_docker_swarm.py
index c41e1e2770537..8523644888de4 100644
--- a/tests/providers/docker/operators/test_docker_swarm.py
+++ b/tests/providers/docker/operators/test_docker_swarm.py
@@ -21,8 +21,7 @@
 
 import pytest
 import requests
-from docker import APIClient
-from docker.types import Mount
+from docker import APIClient, types
 from parameterized import parameterized
 
 from airflow.exceptions import AirflowException
@@ -66,22 +65,31 @@ def _client_service_logs_effect():
             mem_limit='128m',
             user='unittest',
             task_id='unittest',
-            mounts=[Mount(source='/host/path', target='/container/path', type='bind')],
+            mounts=[types.Mount(source='/host/path', target='/container/path', type='bind')],
             auto_remove=True,
             tty=True,
+            configs=[types.ConfigReference(config_id="dummy_cfg_id", config_name="dummy_cfg_name")],
+            secrets=[types.SecretReference(secret_id="dummy_secret_id", secret_name="dummy_secret_name")],
+            mode=types.ServiceMode(mode="replicated", replicas=3),
+            networks=["dummy_network"],
         )
         operator.execute(None)
 
         types_mock.TaskTemplate.assert_called_once_with(
-            container_spec=mock_obj, restart_policy=mock_obj, resources=mock_obj
+            container_spec=mock_obj,
+            restart_policy=mock_obj,
+            resources=mock_obj,
+            networks=["dummy_network"],
         )
         types_mock.ContainerSpec.assert_called_once_with(
            image='ubuntu:latest',
            command='env',
            user='unittest',
-            mounts=[Mount(source='/host/path', target='/container/path', type='bind')],
+            mounts=[types.Mount(source='/host/path', target='/container/path', type='bind')],
            tty=True,
            env={'UNIT': 'TEST', 'AIRFLOW_TMP_DIR': '/tmp/airflow'},
+            configs=[types.ConfigReference(config_id="dummy_cfg_id", config_name="dummy_cfg_name")],
+            secrets=[types.SecretReference(secret_id="dummy_secret_id", secret_name="dummy_secret_name")],
         )
         types_mock.RestartPolicy.assert_called_once_with(condition='none')
         types_mock.Resources.assert_called_once_with(mem_limit='128m')
@@ -99,6 +107,7 @@ def _client_service_logs_effect():
         assert csargs == (mock_obj,)
         assert cskwargs['labels'] == {'name': 'airflow__adhoc_airflow__unittest'}
         assert cskwargs['name'].startswith('airflow-')
+        assert cskwargs['mode'] == types.ServiceMode(mode="replicated", replicas=3)
         assert client_mock.tasks.call_count == 5
         client_mock.remove_service.assert_called_once_with('some_id')
 
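
A minimal usage sketch of the new parameters from a DAG, not part of the patch: the DAG id, task id, config/secret IDs and names, and network name below are hypothetical placeholders; only the import paths and parameter names come from the change above.

# Usage sketch only -- IDs, names, and the network are hypothetical placeholders.
from datetime import datetime

from docker import types

from airflow import DAG
from airflow.providers.docker.operators.docker_swarm import DockerSwarmOperator

with DAG(
    dag_id="docker_swarm_example",
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
) as dag:
    DockerSwarmOperator(
        task_id="swarm_service",
        image="ubuntu:latest",
        command="env",
        # Expose an existing Swarm config and secret to the service's containers.
        configs=[types.ConfigReference(config_id="<config-id>", config_name="app-config")],
        secrets=[types.SecretReference(secret_id="<secret-id>", secret_name="db-password")],
        # Deploy as a replicated service with three replicas.
        mode=types.ServiceMode("replicated", replicas=3),
        # Attach the service to an existing overlay network.
        networks=["my-overlay-network"],
        enable_logging=True,
    )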