From 0e99f962744724ee7f6f68ebd811c9491dff8953 Mon Sep 17 00:00:00 2001 From: "David Lum/./Affiliates/Samsung Electronics" Date: Thu, 1 Apr 2021 16:27:15 -0400 Subject: [PATCH] Validate that the extra parameter is parseable as JSON Before this commit there was a documented but unenforced limitation that the extra parameter be encoded JSON. In #15013 this issue garnered attention and motivated this PR. --- airflow/models/connection.py | 23 +++++++++++++++++++++- tests/models/test_connection.py | 35 +++++++++++++++++++++++---------- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/airflow/models/connection.py b/airflow/models/connection.py index 67a3b4dd4802a..e20e7911238c4 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -137,7 +137,28 @@ def __init__( # pylint: disable=too-many-arguments self.password = password self.schema = schema self.port = port - self.extra = extra + self.extra = self.validate_extra(extra) + + @staticmethod + def validate_extra(extra: str) -> Optional[str]: + """ + `extra` parameter is a JSON encoded object. This methods validates that the data + adheres to this specification. + + :param extra: The extra section of the . + :type extra: str + + :return str + """ + if not extra: + return None + + try: + json.loads(extra) + except JSONDecodeError as e: + raise AirflowException("The `extra` section of a Connection must be valid JSON") from e + + return extra def parse_from_uri(self, **uri): """This method is deprecated. Please use uri parameter in constructor.""" diff --git a/tests/models/test_connection.py b/tests/models/test_connection.py index 526d029694809..88e33f90ef2b4 100644 --- a/tests/models/test_connection.py +++ b/tests/models/test_connection.py @@ -72,18 +72,20 @@ def test_connection_extra_no_encryption(self): is set to a non-base64-encoded string and the extra is stored without encryption. """ - test_connection = Connection(extra='testextra') + extra = '{"test":"extra"}' + test_connection = Connection(extra=extra) assert not test_connection.is_extra_encrypted - assert test_connection.extra == 'testextra' + assert test_connection.extra == extra @conf_vars({('core', 'fernet_key'): Fernet.generate_key().decode()}) def test_connection_extra_with_encryption(self): """ Tests extras on a new connection with encryption. """ - test_connection = Connection(extra='testextra') + extra = '{"test":"extra"}' + test_connection = Connection(extra=extra) assert test_connection.is_extra_encrypted - assert test_connection.extra == 'testextra' + assert test_connection.extra == extra def test_connection_extra_with_encryption_rotate_fernet_key(self): """ @@ -92,22 +94,23 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self): key1 = Fernet.generate_key() key2 = Fernet.generate_key() + extra = '{"test":"extra"}' with conf_vars({('core', 'fernet_key'): key1.decode()}): - test_connection = Connection(extra='testextra') + test_connection = Connection(extra=extra) assert test_connection.is_extra_encrypted - assert test_connection.extra == 'testextra' - assert Fernet(key1).decrypt(test_connection._extra.encode()) == b'testextra' + assert test_connection.extra == extra + assert Fernet(key1).decrypt(test_connection._extra.encode()) == bytes(extra, 'utf-8') # Test decrypt of old value with new key with conf_vars({('core', 'fernet_key'): ','.join([key2.decode(), key1.decode()])}): crypto._fernet = None - assert test_connection.extra == 'testextra' + assert test_connection.extra == extra # Test decrypt of new value with new key test_connection.rotate_fernet_key() assert test_connection.is_extra_encrypted - assert test_connection.extra == 'testextra' - assert Fernet(key2).decrypt(test_connection._extra.encode()) == b'testextra' + assert test_connection.extra == extra + assert Fernet(key2).decrypt(test_connection._extra.encode()) == bytes(extra, 'utf-8') test_from_uri_params = [ UriTestCaseConfig( @@ -548,3 +551,15 @@ def test_connection_mixed(self): ), ): Connection(conn_id="TEST_ID", uri="mysql://", schema="AAA") + + def test_connection_extra_validation_raises(self): + with self.assertRaises(AirflowException): + Connection(extra='extra') + + def test_connection_extra_validation_allows_valid_json(self): + con = Connection(extra='{"foo":"bar", "baz":"taz"}') + assert con.extra_dejson == {'foo': 'bar', 'baz': "taz"} + + def test_connection_extra_validation_allows_none(self): + con = Connection(extra=None) + assert con.extra_dejson == {}