Skip to content

Commit

Permalink
Fix tests/models/test_variable.py for database isolation mode (#41414)
Browse files Browse the repository at this point in the history
* Fix tests/models/test_variable.py for database isolation mode

* Review feedback
  • Loading branch information
jscheffl authored Aug 13, 2024
1 parent f640544 commit 736ebfe
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 12 deletions.
9 changes: 5 additions & 4 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def initialize_method_map() -> dict[str, Callable]:
# XCom.get_many, # Not supported because it returns query
XCom.clear,
XCom.set,
Variable.set,
Variable.update,
Variable.delete,
Variable._set,
Variable._update,
Variable._delete,
DAG.fetch_callback,
DAG.fetch_dagrun,
DagRun.fetch_task_instances,
Expand Down Expand Up @@ -237,7 +237,8 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
response = json.dumps(output_json) if output_json is not None else None
log.info("Sending response: %s", response)
return Response(response=response, headers={"Content-Type": "application/json"})
except AirflowException as e: # In case of AirflowException transport the exception class back to caller
# In case of AirflowException or other selective known types, transport the exception class back to caller
except (KeyError, AttributeError, AirflowException) as e:
exception_json = BaseSerialization.serialize(e, use_pydantic_models=True)
response = json.dumps(exception_json)
log.info("Sending exception response: %s", response)
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_internal/internal_api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def wrapper(*args, **kwargs):
if result is None or result == b"":
return None
result = BaseSerialization.deserialize(json.loads(result), use_pydantic_models=True)
if isinstance(result, AirflowException):
if isinstance(result, (KeyError, AttributeError, AirflowException)):
raise result
return result

Expand Down
66 changes: 63 additions & 3 deletions airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def get(

@staticmethod
@provide_session
@internal_api_call
def set(
key: str,
value: Any,
Expand All @@ -167,6 +166,35 @@ def set(
This operation overwrites an existing variable.
:param key: Variable Key
:param value: Value to set for the Variable
:param description: Description of the Variable
:param serialize_json: Serialize the value to a JSON string
:param session: Session
"""
Variable._set(
key=key, value=value, description=description, serialize_json=serialize_json, session=session
)
# invalidate key in cache for faster propagation
# we cannot save the value set because it's possible that it's shadowed by a custom backend
# (see call to check_for_write_conflict above)
SecretCache.invalidate_variable(key)

@staticmethod
@provide_session
@internal_api_call
def _set(
key: str,
value: Any,
description: str | None = None,
serialize_json: bool = False,
session: Session = None,
) -> None:
"""
Set a value for an Airflow Variable with a given Key.
This operation overwrites an existing variable.
:param key: Variable Key
:param value: Value to set for the Variable
:param description: Description of the Variable
Expand All @@ -190,7 +218,6 @@ def set(

@staticmethod
@provide_session
@internal_api_call
def update(
key: str,
value: Any,
Expand All @@ -200,6 +227,27 @@ def update(
"""
Update a given Airflow Variable with the Provided value.
:param key: Variable Key
:param value: Value to set for the Variable
:param serialize_json: Serialize the value to a JSON string
:param session: Session
"""
Variable._update(key=key, value=value, serialize_json=serialize_json, session=session)
# We need to invalidate the cache for internal API cases on the client side
SecretCache.invalidate_variable(key)

@staticmethod
@provide_session
@internal_api_call
def _update(
key: str,
value: Any,
serialize_json: bool = False,
session: Session = None,
) -> None:
"""
Update a given Airflow Variable with the Provided value.
:param key: Variable Key
:param value: Value to set for the Variable
:param serialize_json: Serialize the value to a JSON string
Expand All @@ -219,11 +267,23 @@ def update(

@staticmethod
@provide_session
@internal_api_call
def delete(key: str, session: Session = None) -> int:
"""
Delete an Airflow Variable for a given key.
:param key: Variable Keys
"""
rows = Variable._delete(key=key, session=session)
SecretCache.invalidate_variable(key)
return rows

@staticmethod
@provide_session
@internal_api_call
def _delete(key: str, session: Session = None) -> int:
"""
Delete an Airflow Variable for a given key.
:param key: Variable Keys
"""
rows = session.execute(delete(Variable).where(Variable.key == key)).rowcount
Expand Down
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class DagAttributeTypes(str, Enum):
RELATIVEDELTA = "relativedelta"
BASE_TRIGGER = "base_trigger"
AIRFLOW_EXC_SER = "airflow_exc_ser"
BASE_EXC_SER = "base_exc_ser"
DICT = "dict"
SET = "set"
TUPLE = "tuple"
Expand Down
16 changes: 14 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,15 @@ def serialize(
),
type_=DAT.AIRFLOW_EXC_SER,
)
elif isinstance(var, (KeyError, AttributeError)):
return cls._encode(
cls.serialize(
{"exc_cls_name": var.__class__.__name__, "args": [var.args], "kwargs": {}},
use_pydantic_models=use_pydantic_models,
strict=strict,
),
type_=DAT.BASE_EXC_SER,
)
elif isinstance(var, BaseTrigger):
return cls._encode(
cls.serialize(var.serialize(), use_pydantic_models=use_pydantic_models, strict=strict),
Expand Down Expand Up @@ -834,13 +843,16 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
return decode_timezone(var)
elif type_ == DAT.RELATIVEDELTA:
return decode_relativedelta(var)
elif type_ == DAT.AIRFLOW_EXC_SER:
elif type_ == DAT.AIRFLOW_EXC_SER or type_ == DAT.BASE_EXC_SER:
deser = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
exc_cls_name = deser["exc_cls_name"]
args = deser["args"]
kwargs = deser["kwargs"]
del deser
exc_cls = import_string(exc_cls_name)
if type_ == DAT.AIRFLOW_EXC_SER:
exc_cls = import_string(exc_cls_name)
else:
exc_cls = import_string(f"builtins.{exc_cls_name}")
return exc_cls(*args, **kwargs)
elif type_ == DAT.BASE_TRIGGER:
tr_cls_name, kwargs = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
Expand Down
8 changes: 6 additions & 2 deletions tests/models/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def setup_test_cases(self):
db.clear_db_variables()
crypto._fernet = None

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other fernet
@conf_vars({("core", "fernet_key"): "", ("core", "unit_test_mode"): "True"})
def test_variable_no_encryption(self, session):
"""
Expand All @@ -60,6 +61,7 @@ def test_variable_no_encryption(self, session):
# should mask anything. That logic is tested in test_secrets_masker.py
self.mask_secret.assert_called_once_with("value", "key")

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other fernet
@conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
def test_variable_with_encryption(self, session):
"""
Expand All @@ -70,6 +72,7 @@ def test_variable_with_encryption(self, session):
assert test_var.is_encrypted
assert test_var.val == "value"

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other fernet
@pytest.mark.parametrize("test_value", ["value", ""])
def test_var_with_encryption_rotate_fernet_key(self, test_value, session):
"""
Expand Down Expand Up @@ -152,6 +155,7 @@ def test_variable_update(self, session):
Variable.update(key="test_key", value="value2", session=session)
assert "value2" == Variable.get("test_key")

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, API server has other ENV
def test_variable_update_fails_on_non_metastore_variable(self, session):
with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="env-value"):
with pytest.raises(AttributeError):
Expand Down Expand Up @@ -281,6 +285,7 @@ def test_caching_caches(self, mock_ensure_secrets: mock.Mock):
mock_backend.get_variable.assert_called_once() # second call was not made because of cache
assert first == second

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other env
def test_cache_invalidation_on_set(self, session):
with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="from_env"):
a = Variable.get("key") # value is saved in cache
Expand Down Expand Up @@ -316,7 +321,7 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m
val=variable_value,
)
session.add(var)
session.flush()
session.commit()
# Make sure we re-load it, not just get the cached object back
session.expunge(var)
_secrets_masker().patterns = set()
Expand All @@ -326,5 +331,4 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m
for expected_masked_value in expected_masked_values:
assert expected_masked_value in _secrets_masker().patterns
finally:
session.rollback()
db.clear_db_variables()

0 comments on commit 736ebfe

Please sign in to comment.