Skip to content

Commit

Permalink
Add support for configs, secrets, networks and replicas for DockerSwa…
Browse files Browse the repository at this point in the history
…rmOperator (#17474)
  • Loading branch information
enima2684 authored Aug 19, 2021
1 parent 6ba538f commit 4da4c18
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
35 changes: 33 additions & 2 deletions airflow/providers/docker/operators/docker_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Run ephemeral Docker Swarm services"""
from typing import Optional
from typing import List, Optional, Union

import requests
from docker import types
Expand Down Expand Up @@ -93,13 +93,40 @@ class DockerSwarmOperator(DockerOperator):
Supported only if the Docker engine is using json-file or journald logging drivers.
The `tty` parameter should be set to use this with Python applications.
:type enable_logging: bool
:param configs: List of docker configs to be exposed to the containers of the swarm service.
The configs are ConfigReference objects as per the docker api
[https://docker-py.readthedocs.io/en/stable/services.html#docker.models.services.ServiceCollection.create]_
:type configs: List[docker.types.ConfigReference]
:param secrets: List of docker secrets to be exposed to the containers of the swarm service.
The secrets are SecretReference objects as per the docker create_service api.
[https://docker-py.readthedocs.io/en/stable/services.html#docker.models.services.ServiceCollection.create]_
:type secrets: List[docker.types.SecretReference]
:param mode: Indicate whether a service should be deployed as a replicated or global service,
and associated parameters
:type mode: docker.types.ServiceMode
:param networks: List of network names or IDs or NetworkAttachmentConfig to attach the service to.
:type networks: List[Union[str, NetworkAttachmentConfig]]
"""

def __init__(self, *, image: str, enable_logging: bool = True, **kwargs) -> None:
def __init__(
self,
*,
image: str,
enable_logging: bool = True,
configs: Optional[List[types.ConfigReference]] = None,
secrets: Optional[List[types.SecretReference]] = None,
mode: Optional[types.ServiceMode] = None,
networks: Optional[List[Union[str, types.NetworkAttachmentConfig]]] = None,
**kwargs,
) -> None:
super().__init__(image=image, **kwargs)

self.enable_logging = enable_logging
self.service = None
self.configs = configs
self.secrets = secrets
self.mode = mode
self.networks = networks

def execute(self, context) -> None:
self.cli = self._get_cli()
Expand All @@ -121,12 +148,16 @@ def _run_service(self) -> None:
env=self.environment,
user=self.user,
tty=self.tty,
configs=self.configs,
secrets=self.secrets,
),
restart_policy=types.RestartPolicy(condition='none'),
resources=types.Resources(mem_limit=self.mem_limit),
networks=self.networks,
),
name=f'airflow-{get_random_string()}',
labels={'name': f'airflow__{self.dag_id}__{self.task_id}'},
mode=self.mode,
)

self.log.info('Service started: %s', str(self.service))
Expand Down
19 changes: 14 additions & 5 deletions tests/providers/docker/operators/test_docker_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@

import pytest
import requests
from docker import APIClient
from docker.types import Mount
from docker import APIClient, types
from parameterized import parameterized

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -66,22 +65,31 @@ def _client_service_logs_effect():
mem_limit='128m',
user='unittest',
task_id='unittest',
mounts=[Mount(source='/host/path', target='/container/path', type='bind')],
mounts=[types.Mount(source='/host/path', target='/container/path', type='bind')],
auto_remove=True,
tty=True,
configs=[types.ConfigReference(config_id="dummy_cfg_id", config_name="dummy_cfg_name")],
secrets=[types.SecretReference(secret_id="dummy_secret_id", secret_name="dummy_secret_name")],
mode=types.ServiceMode(mode="replicated", replicas=3),
networks=["dummy_network"],
)
operator.execute(None)

types_mock.TaskTemplate.assert_called_once_with(
container_spec=mock_obj, restart_policy=mock_obj, resources=mock_obj
container_spec=mock_obj,
restart_policy=mock_obj,
resources=mock_obj,
networks=["dummy_network"],
)
types_mock.ContainerSpec.assert_called_once_with(
image='ubuntu:latest',
command='env',
user='unittest',
mounts=[Mount(source='/host/path', target='/container/path', type='bind')],
mounts=[types.Mount(source='/host/path', target='/container/path', type='bind')],
tty=True,
env={'UNIT': 'TEST', 'AIRFLOW_TMP_DIR': '/tmp/airflow'},
configs=[types.ConfigReference(config_id="dummy_cfg_id", config_name="dummy_cfg_name")],
secrets=[types.SecretReference(secret_id="dummy_secret_id", secret_name="dummy_secret_name")],
)
types_mock.RestartPolicy.assert_called_once_with(condition='none')
types_mock.Resources.assert_called_once_with(mem_limit='128m')
Expand All @@ -99,6 +107,7 @@ def _client_service_logs_effect():
assert csargs == (mock_obj,)
assert cskwargs['labels'] == {'name': 'airflow__adhoc_airflow__unittest'}
assert cskwargs['name'].startswith('airflow-')
assert cskwargs['mode'] == types.ServiceMode(mode="replicated", replicas=3)
assert client_mock.tasks.call_count == 5
client_mock.remove_service.assert_called_once_with('some_id')

Expand Down

0 comments on commit 4da4c18

Please sign in to comment.