Skip to content

Commit

Permalink
[CP][AIRFLOW-4316] support setting kubernetes_environment_variables c…
Browse files Browse the repository at this point in the history
…onfig section from env var (twitter-forks#46)


Co-authored-by: Vishesh Jain <[email protected]>
  • Loading branch information
vshshjn7 and Vishesh Jain authored May 4, 2020
1 parent a68e2b3 commit a507ee7
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 22 deletions.
19 changes: 16 additions & 3 deletions airflow/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def run_command(command):
return output


def _read_default_config_file(file_name):
def _read_default_config_file(file_name: str) -> str:
templates_dir = os.path.join(os.path.dirname(__file__), 'config_templates')
file_path = os.path.join(templates_dir, file_name)
if six.PY2:
Expand Down Expand Up @@ -165,6 +165,12 @@ class AirflowConfigParser(ConfigParser):
},
}

# This method transforms option names on every read, get, or set operation.
# This changes from the default behaviour of ConfigParser from lowercasing
# to instead be case-preserving
def optionxform(self, optionstr: str) -> str:
return optionstr

def __init__(self, default_config=None, *args, **kwargs):
super(AirflowConfigParser, self).__init__(*args, **kwargs)

Expand Down Expand Up @@ -436,8 +442,15 @@ def as_dict(
opt = opt.replace('%', '%%')
if display_source:
opt = (opt, 'env var')
cfg.setdefault(section.lower(), OrderedDict()).update(
{key.lower(): opt})

section = section.lower()
# if we lower key for kubernetes_environment_variables section,
# then we won't be able to set any Airflow environment
# variables. Airflow only parse environment variables starts
# with AIRFLOW_. Therefore, we need to make it a special case.
if section != 'kubernetes_environment_variables':
key = key.lower()
cfg.setdefault(section, OrderedDict()).update({key: opt})

# add bash commands
if include_cmds:
Expand Down
2 changes: 1 addition & 1 deletion airflow/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# under the License.
#

version = '1.10.4+twtr8'
version = '1.10.4+twtr9'

10 changes: 5 additions & 5 deletions tests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def test_config_throw_error_when_original_and_fallback_is_absent(self):
self.assertTrue(configuration.conf.has_option("core", "FERNET_KEY"))
self.assertFalse(configuration.conf.has_option("core", "FERNET_KEY_CMD"))

with conf_vars({('core', 'FERNET_KEY'): None}):
with conf_vars({('core', 'fernet_key'): None}):
with self.assertRaises(AirflowConfigException) as cm:
configuration.conf.get("core", "FERNET_KEY")

Expand Down Expand Up @@ -2721,7 +2721,7 @@ def test_default_backend(self, mock_send_email):

@mock.patch('airflow.utils.email.send_email_smtp')
def test_custom_backend(self, mock_send_email):
with conf_vars({('email', 'EMAIL_BACKEND'): 'tests.core.send_email_test'}):
with conf_vars({('email', 'email_backend'): 'tests.core.send_email_test'}):
utils.email.send_email('to', 'subject', 'content')
send_email_test.assert_called_with(
'to', 'subject', 'content', files=None, dryrun=False,
Expand Down Expand Up @@ -2804,7 +2804,7 @@ def test_send_mime(self, mock_smtp, mock_smtp_ssl):
def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
with conf_vars({('smtp', 'SMTP_SSL'): 'True'}):
with conf_vars({('smtp', 'smtp_ssl'): 'True'}):
utils.email.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=False)
self.assertFalse(mock_smtp.called)
mock_smtp_ssl.assert_called_with(
Expand All @@ -2818,8 +2818,8 @@ def test_send_mime_noauth(self, mock_smtp, mock_smtp_ssl):
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
with conf_vars({
('smtp', 'SMTP_USER'): None,
('smtp', 'SMTP_PASSWORD'): None,
('smtp', 'smtp_user'): None,
('smtp', 'smtp_password'): None,
}):
utils.email.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=False)
self.assertFalse(mock_smtp_ssl.called)
Expand Down
8 changes: 4 additions & 4 deletions tests/models/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def setUp(self):
def tearDown(self):
crypto._fernet = None

@conf_vars({('core', 'FERNET_KEY'): ''})
@conf_vars({('core', 'fernet_key'): ''})
def test_connection_extra_no_encryption(self):
"""
Tests extras on a new connection without encryption. The fernet key
Expand All @@ -47,7 +47,7 @@ def test_connection_extra_no_encryption(self):
self.assertFalse(test_connection.is_extra_encrypted)
self.assertEqual(test_connection.extra, 'testextra')

@conf_vars({('core', 'FERNET_KEY'): Fernet.generate_key().decode()})
@conf_vars({('core', 'fernet_key'): Fernet.generate_key().decode()})
def test_connection_extra_with_encryption(self):
"""
Tests extras on a new connection with encryption.
Expand All @@ -63,14 +63,14 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
key1 = Fernet.generate_key()
key2 = Fernet.generate_key()

with conf_vars({('core', 'FERNET_KEY'): key1.decode()}):
with conf_vars({('core', 'fernet_key'): key1.decode()}):
test_connection = Connection(extra='testextra')
self.assertTrue(test_connection.is_extra_encrypted)
self.assertEqual(test_connection.extra, 'testextra')
self.assertEqual(Fernet(key1).decrypt(test_connection._extra.encode()), b'testextra')

# Test decrypt of old value with new key
with conf_vars({('core', 'FERNET_KEY'): ','.join([key2.decode(), key1.decode()])}):
with conf_vars({('core', 'fernet_key'): ','.join([key2.decode(), key1.decode()])}):
crypto._fernet = None
self.assertEqual(test_connection.extra, 'testextra')

Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,8 +963,8 @@ def test_email_alert_with_config(self, mock_send_email):
ti = TI(
task=task, execution_date=datetime.datetime.now())

configuration.set('email', 'SUBJECT_TEMPLATE', '/subject/path')
configuration.set('email', 'HTML_CONTENT_TEMPLATE', '/html_content/path')
configuration.set('email', 'subject_template', '/subject/path')
configuration.set('email', 'html_content_template', '/html_content/path')

opener = mock_open(read_data='template: {{ti.task_id}}')
with patch('airflow.models.taskinstance.open', opener, create=True):
Expand Down
8 changes: 4 additions & 4 deletions tests/models/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def setUp(self):
def tearDown(self):
crypto._fernet = None

@conf_vars({('core', 'FERNET_KEY'): ''})
@conf_vars({('core', 'fernet_key'): ''})
def test_variable_no_encryption(self):
"""
Test variables without encryption
Expand All @@ -44,7 +44,7 @@ def test_variable_no_encryption(self):
self.assertFalse(test_var.is_encrypted)
self.assertEqual(test_var.val, 'value')

@conf_vars({('core', 'FERNET_KEY'): Fernet.generate_key().decode()})
@conf_vars({('core', 'fernet_key'): Fernet.generate_key().decode()})
def test_variable_with_encryption(self):
"""
Test variables with encryption
Expand All @@ -62,7 +62,7 @@ def test_var_with_encryption_rotate_fernet_key(self):
key1 = Fernet.generate_key()
key2 = Fernet.generate_key()

with conf_vars({('core', 'FERNET_KEY'): key1.decode()}):
with conf_vars({('core', 'fernet_key'): key1.decode()}):
Variable.set('key', 'value')
session = settings.Session()
test_var = session.query(Variable).filter(Variable.key == 'key').one()
Expand All @@ -71,7 +71,7 @@ def test_var_with_encryption_rotate_fernet_key(self):
self.assertEqual(Fernet(key1).decrypt(test_var._val.encode()), b'value')

# Test decrypt of old value with new key
with conf_vars({('core', 'FERNET_KEY'): ','.join([key2.decode(), key1.decode()])}):
with conf_vars({('core', 'fernet_key'): ','.join([key2.decode(), key1.decode()])}):
crypto._fernet = None
self.assertEqual(test_var.val, 'value')

Expand Down
2 changes: 1 addition & 1 deletion tests/operators/test_email_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ def _run_as_operator(self, **kwargs):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

def test_execute(self):
with conf_vars({('email', 'EMAIL_BACKEND'): 'tests.operators.test_email_operator.send_email_test'}):
with conf_vars({('email', 'email_backend'): 'tests.operators.test_email_operator.send_email_test'}):
self._run_as_operator()
assert send_email_test.call_count == 1
33 changes: 31 additions & 2 deletions tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ConfTest(unittest.TestCase):
def setUpClass(cls):
os.environ['AIRFLOW__TESTSECTION__TESTKEY'] = 'testvalue'
os.environ['AIRFLOW__TESTSECTION__TESTPERCENT'] = 'with%percent'
os.environ['AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY'] = 'nested'

configuration.load_test_config()
conf.set('core', 'percent', 'with%%inside')

Expand Down Expand Up @@ -93,6 +93,13 @@ def test_airflow_config_override(self):
configuration.get_airflow_config('/home//airflow'),
'/path/to/airflow/airflow.cfg')

def test_case_sensitivity(self):
# section and key are case insensitive for get method
# note: this is not the case for as_dict method
self.assertEqual(conf.get("core", "percent"), "with%inside")
self.assertEqual(conf.get("core", "PERCENT"), "with%inside")
self.assertEqual(conf.get("CORE", "PERCENT"), "with%inside")

def test_env_var_config(self):
opt = conf.get('testsection', 'testkey')
self.assertEqual(opt, 'testvalue')
Expand All @@ -102,10 +109,13 @@ def test_env_var_config(self):

self.assertTrue(conf.has_option('testsection', 'testkey'))

os.environ['AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY'] = 'nested'
opt = conf.get('kubernetes_environment_variables', 'AIRFLOW__TESTSECTION__TESTKEY')
self.assertEqual(opt, 'nested')
del os.environ['AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY']

def test_conf_as_dict(self):
os.environ['AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY'] = 'nested'
cfg_dict = conf.as_dict()

# test that configs are picked up
Expand All @@ -116,8 +126,9 @@ def test_conf_as_dict(self):
# test env vars
self.assertEqual(cfg_dict['testsection']['testkey'], '< hidden >')
self.assertEqual(
cfg_dict['kubernetes_environment_variables']['airflow__testsection__testkey'],
cfg_dict['kubernetes_environment_variables']['AIRFLOW__TESTSECTION__TESTKEY'],
'< hidden >')
del os.environ['AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY']

def test_conf_as_dict_source(self):
# test display_source
Expand Down Expand Up @@ -328,6 +339,24 @@ def test_getsection(self):
test_conf.getsection('testsection')
)

def test_kubernetes_environment_variables_section(self):
TEST_CONFIG = '''
[kubernetes_environment_variables]
key1 = hello
AIRFLOW_HOME = /root/airflow
'''
TEST_CONFIG_DEFAULT = '''
[kubernetes_environment_variables]
'''
test_conf = AirflowConfigParser(
default_config=parameterized_config(TEST_CONFIG_DEFAULT))
test_conf.read_string(TEST_CONFIG)

self.assertEqual(
OrderedDict([('key1', 'hello'), ('AIRFLOW_HOME', '/root/airflow')]),
test_conf.getsection('kubernetes_environment_variables')
)

def test_broker_transport_options(self):
section_dict = conf.getsection("celery_broker_transport_options")
self.assertTrue(isinstance(section_dict['visibility_timeout'], int))
Expand Down

0 comments on commit a507ee7

Please sign in to comment.