Skip to content

Commit

Permalink
Adding feature in bash operator to append the user defined env variab…
Browse files Browse the repository at this point in the history
…le to system env variable (#18944)
  • Loading branch information
PraveenA95 authored Oct 13, 2021
1 parent b2045d6 commit d4a3d2b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
14 changes: 13 additions & 1 deletion airflow/operators/bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ class BashOperator(BaseOperator):
of inheriting the current process environment, which is the default
behavior. (templated)
:type env: dict
:param append_env: If False (default), only the environment variables passed in ``env`` are used
and the current process environment is not inherited (unless ``env`` is None, in which case the
current process environment is used). If True, the environment variables of the current process are
inherited, and the variables passed by the user in ``env`` either override the inherited variables
of the same name or are added to them as new variables.
:type append_env: bool
:param output_encoding: Output encoding of bash command
:type output_encoding: str
:param skip_exit_code: If task exits with this exit code, leave the task
Expand Down Expand Up @@ -135,6 +140,7 @@ def __init__(
*,
bash_command: str,
env: Optional[Dict[str, str]] = None,
append_env: bool = False,
output_encoding: str = 'utf-8',
skip_exit_code: int = 99,
cwd: str = None,
Expand All @@ -146,6 +152,7 @@ def __init__(
self.output_encoding = output_encoding
self.skip_exit_code = skip_exit_code
self.cwd = cwd
self.append_env = append_env
if kwargs.get('xcom_push') is not None:
raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead")

Expand All @@ -156,9 +163,14 @@ def subprocess_hook(self):

def get_env(self, context):
"""Builds the set of environment variables to be exposed for the bash command"""
system_env = os.environ.copy()
env = self.env
if env is None:
env = os.environ.copy()
env = system_env
else:
if self.append_env:
system_env.update(env)
env = system_env

airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
self.log.debug(
Expand Down
29 changes: 19 additions & 10 deletions tests/operators/test_bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from parameterized import parameterized

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models import DagRun
from airflow.models.dag import DAG
from airflow.operators.bash import BashOperator
from airflow.utils import timezone
Expand All @@ -38,11 +37,26 @@


class TestBashOperator(unittest.TestCase):
def test_echo_env_variables(self):
@parameterized.expand(
[
(False, None, 'MY_PATH_TO_AIRFLOW_HOME'),
(True, {'AIRFLOW_HOME': 'OVERRIDDEN_AIRFLOW_HOME'}, 'OVERRIDDEN_AIRFLOW_HOME'),
]
)
def test_echo_env_variables(self, append_env, user_defined_env, expected_airflow_home):
"""
Test that env variables are exported correctly to the task bash environment.
"""
utc_now = datetime.utcnow().replace(tzinfo=timezone.utc)
expected = (
f"{expected_airflow_home}\n"
"AWESOME_PYTHONPATH\n"
"bash_op_test\n"
"echo_env_vars\n"
f"{utc_now.isoformat()}\n"
f"manual__{utc_now.isoformat()}\n"
)

dag = DAG(
dag_id='bash_op_test',
default_args={'owner': 'airflow', 'retries': 100, 'start_date': DEFAULT_DATE},
Expand All @@ -68,6 +82,8 @@ def test_echo_env_variables(self):
'echo $AIRFLOW_CTX_TASK_ID>> {0};'
'echo $AIRFLOW_CTX_EXECUTION_DATE>> {0};'
'echo $AIRFLOW_CTX_DAG_RUN_ID>> {0};'.format(tmp_file.name),
append_env=append_env,
env=user_defined_env,
)

with mock.patch.dict(
Expand All @@ -77,13 +93,7 @@ def test_echo_env_variables(self):

with open(tmp_file.name) as file:
output = ''.join(file.readlines())
assert 'MY_PATH_TO_AIRFLOW_HOME' in output
# exported in run-tests as part of PYTHONPATH
assert 'AWESOME_PYTHONPATH' in output
assert 'bash_op_test' in output
assert 'echo_env_vars' in output
assert utc_now.isoformat() in output
assert DagRun.generate_run_id(DagRunType.MANUAL, utc_now) in output
assert expected == output

@parameterized.expand(
[
Expand Down Expand Up @@ -147,7 +157,6 @@ def test_cwd_is_file(self):
BashOperator(task_id='abc', bash_command=test_cmd, cwd=tmp_file.name).execute({})

def test_valid_cwd(self):

test_cmd = 'set -e; echo "xxxx" |tee outputs.txt'
with TemporaryDirectory(prefix='test_command_with_cwd') as test_cwd_folder:
# Test everything went alright
Expand Down

0 comments on commit d4a3d2b

Please sign in to comment.