diff --git a/airflow/providers/docker/operators/docker.py b/airflow/providers/docker/operators/docker.py index ceb40f177066..4eb2560fc54b 100644 --- a/airflow/providers/docker/operators/docker.py +++ b/airflow/providers/docker/operators/docker.py @@ -35,6 +35,15 @@ from airflow.utils.context import Context +def stringify(line: Union[str, bytes]): + """Make sure string is returned even if bytes are passed. Docker stream can return bytes.""" + decode_method = getattr(line, 'decode', None) + if decode_method: + return decode_method(encoding='utf-8', errors='surrogateescape') + else: + return line + + class DockerOperator(BaseOperator): """ Execute a command inside a docker container. @@ -222,7 +231,7 @@ def get_hook(self) -> DockerHook: tls=self.__get_tls_config(), ) - def _run_image(self) -> Optional[str]: + def _run_image(self) -> Optional[Union[List[str], str]]: """Run a Docker container with the provided image""" self.log.info('Starting docker container from image %s', self.image) if not self.cli: @@ -245,7 +254,9 @@ def _run_image(self) -> Optional[str]: else: return self._run_image_with_mounts(self.mounts, add_tmp_variable=False) - def _run_image_with_mounts(self, target_mounts, add_tmp_variable: bool) -> Optional[str]: + def _run_image_with_mounts( + self, target_mounts, add_tmp_variable: bool + ) -> Optional[Union[List[str], str]]: if add_tmp_variable: self.environment['AIRFLOW_TMP_DIR'] = self.tmp_dir else: @@ -281,10 +292,7 @@ def _run_image_with_mounts(self, target_mounts, add_tmp_variable: bool) -> Optio log_lines = [] for log_chunk in logstream: - if hasattr(log_chunk, 'decode'): - # Note that lines returned can also be byte sequences so we have to handle decode here - log_chunk = log_chunk.decode('utf-8', errors='surrogateescape') - log_chunk = log_chunk.strip() + log_chunk = stringify(log_chunk).strip() log_lines.append(log_chunk) self.log.info("%s", log_chunk) @@ -302,12 +310,15 @@ def _run_image_with_mounts(self, target_mounts, add_tmp_variable: bool) 
-> Optio 'stderr': True, 'stream': True, } - - return ( - self.cli.logs(**log_parameters) - if self.xcom_all - else self.cli.logs(**log_parameters, tail=1) - ) + try: + if self.xcom_all: + return [stringify(line).strip() for line in self.cli.logs(**log_parameters)] + else: + lines = [stringify(line).strip() for line in self.cli.logs(**log_parameters, tail=1)] + return lines[-1] if lines else None + except StopIteration: + # handle the case when there is not a single line to iterate on + return None return None finally: if self.auto_remove: diff --git a/tests/providers/docker/operators/test_docker.py b/tests/providers/docker/operators/test_docker.py index cda1cf2f77a3..710dad1332ab 100644 --- a/tests/providers/docker/operators/test_docker.py +++ b/tests/providers/docker/operators/test_docker.py @@ -50,12 +50,14 @@ def setUp(self): self.client_mock.pull.return_value = {"status": "pull log"} self.client_mock.wait.return_value = {"StatusCode": 0} self.client_mock.create_host_config.return_value = mock.Mock() - self.log_messages = ['container log 1', 'container log 2'] + self.log_messages = ['container log 😁 ', b'byte string container log'] self.client_mock.attach.return_value = self.log_messages # If logs() is called with tail then only return the last value, otherwise return the whole log. 
self.client_mock.logs.side_effect = ( - lambda **kwargs: self.log_messages[-kwargs['tail']] if 'tail' in kwargs else self.log_messages + lambda **kwargs: iter(self.log_messages[-kwargs['tail'] :]) + if 'tail' in kwargs + else iter(self.log_messages) ) self.client_class_patcher = mock.patch( @@ -429,7 +431,6 @@ def test_execute_with_docker_conn_id_use_hook(self, hook_class_mock): def test_execute_xcom_behavior(self): self.client_mock.pull.return_value = [b'{"status":"pull log"}'] - kwargs = { 'api_version': '1.19', 'command': 'env', @@ -455,13 +456,21 @@ def test_execute_xcom_behavior(self): xcom_all_result = xcom_all_operator.execute(None) no_xcom_push_result = no_xcom_push_operator.execute(None) - assert xcom_push_result == 'container log 2' - assert xcom_all_result == ['container log 1', 'container log 2'] + assert xcom_push_result == 'byte string container log' + assert xcom_all_result == ['container log 😁', 'byte string container log'] assert no_xcom_push_result is None def test_execute_xcom_behavior_bytes(self): + self.log_messages = [b'container log 1 ', b'container log 2'] self.client_mock.pull.return_value = [b'{"status":"pull log"}'] - self.client_mock.attach.return_value = [b'container log 1 ', b'container log 2'] + self.client_mock.attach.return_value = iter([b'container log 1 ', b'container log 2']) + # Make sure the logs side effect is updated after the change + self.client_mock.logs.side_effect = ( + lambda **kwargs: iter(self.log_messages[-kwargs['tail'] :]) + if 'tail' in kwargs + else iter(self.log_messages) + ) + kwargs = { 'api_version': '1.19', 'command': 'env', @@ -487,10 +496,47 @@ def test_execute_xcom_behavior_bytes(self): xcom_all_result = xcom_all_operator.execute(None) no_xcom_push_result = no_xcom_push_operator.execute(None) + # The expected values are the decoded and stripped forms of the bytes log messages set at the top of this test assert xcom_push_result == 'container log 2' assert xcom_all_result == ['container log 1', 'container log 2'] assert no_xcom_push_result is
None + def test_execute_xcom_behavior_no_result(self): + self.log_messages = [] + self.client_mock.pull.return_value = [b'{"status":"pull log"}'] + self.client_mock.attach.return_value = iter([]) + # Make sure the logs side effect is updated after the change + self.client_mock.logs.side_effect = iter([]) + + kwargs = { + 'api_version': '1.19', + 'command': 'env', + 'environment': {'UNIT': 'TEST'}, + 'private_environment': {'PRIVATE': 'MESSAGE'}, + 'image': 'ubuntu:latest', + 'network_mode': 'bridge', + 'owner': 'unittest', + 'task_id': 'unittest', + 'mounts': [Mount(source='/host/path', target='/container/path', type='bind')], + 'working_dir': '/container/path', + 'shm_size': 1000, + 'host_tmp_dir': '/host/airflow', + 'container_name': 'test_container', + 'tty': True, + } + + xcom_push_operator = DockerOperator(**kwargs, do_xcom_push=True, xcom_all=False) + xcom_all_operator = DockerOperator(**kwargs, do_xcom_push=True, xcom_all=True) + no_xcom_push_operator = DockerOperator(**kwargs, do_xcom_push=False) + + xcom_push_result = xcom_push_operator.execute(None) + xcom_all_result = xcom_all_operator.execute(None) + no_xcom_push_result = no_xcom_push_operator.execute(None) + + assert xcom_push_result is None + assert xcom_all_result is None + assert no_xcom_push_result is None + def test_extra_hosts(self): hosts_obj = mock.Mock() operator = DockerOperator(task_id='test', image='test', extra_hosts=hosts_obj)