Skip to content

Commit

Permalink
be more explicit about ser and deser; add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dstandish committed Oct 29, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 0fc6c0e commit 33ac787
Showing 2 changed files with 132 additions and 71 deletions.
36 changes: 23 additions & 13 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
@@ -55,7 +55,6 @@
except ImportError:
HAS_KUBERNETES = False


if TYPE_CHECKING:
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep

@@ -409,19 +408,30 @@ def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) -

@classmethod
def _serialize_param(cls, param: Param):
    """Serialize a single ``Param`` into a plain dict.

    Each attribute is encoded explicitly instead of going through a generic
    dump of the object, so the stored shape is fixed:

    * ``__class``: dotted path built from the param's module and class name.
    * ``default``: the param's current ``value``, run through ``cls._serialize``.
    * ``description``: the param's ``description``, run through ``cls._serialize``.
    * ``schema``: the param's ``schema``, run through ``cls._serialize``.

    :param param: the ``Param`` instance to encode.
    :return: a dict with the four keys listed above.
    """
    return dict(
        __class=f"{param.__module__}.{param.__class__.__name__}",
        default=cls._serialize(param.value),
        description=cls._serialize(param.description),
        schema=cls._serialize(param.schema),
    )

@classmethod
def _deserialize_param(cls, param_dict: Dict):
    """Rebuild a ``Param`` (or subclass) from its serialized dict.

    The class to instantiate is imported from the dotted path stored under
    ``__class``. Only the ``default``, ``description`` and ``schema`` keys are
    read; any that are absent are simply not passed to the constructor.

    An attribute value is decoded with ``cls._deserialize`` only when it looks
    like serializer output (a dict carrying a ``__type`` key). Any other value
    is passed through unchanged, so entries whose attributes were stored raw
    can still be loaded.

    ``param_dict`` is not modified.

    :param param_dict: serialized param; must contain ``__class``.
    :return: an instance of the class named by ``__class``.
    :raises KeyError: if ``__class`` is missing from ``param_dict``.
    """
    class_name = param_dict['__class']
    class_ = import_string(class_name)  # type: Type[Param]
    kwargs = {}
    for attr in ('default', 'description', 'schema'):
        if attr not in param_dict:
            continue
        val = param_dict[attr]
        # Serializer-encoded values are dicts tagged with '__type'; raw values are used as-is.
        is_serialized = isinstance(val, dict) and '__type' in val
        kwargs[attr] = cls._deserialize(val) if is_serialized else val
    return class_(**kwargs)

@classmethod
def _serialize_params_dict(cls, params: ParamsDict):
@@ -437,13 +447,13 @@ def _serialize_params_dict(cls, params: ParamsDict):

@classmethod
def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict:
    """Deserialize a DAG's Params dict.

    Two encodings are accepted per entry:

    * a dict containing ``__class`` is handed to ``cls._deserialize_param``;
    * anything else is treated as an old-style bare value and wrapped in a
      new ``Param``.

    :param encoded_params: mapping of param name to its serialized form.
    :return: a ``ParamsDict`` built from the decoded entries.
    """
    op_params = {}
    for name, encoded in encoded_params.items():
        if isinstance(encoded, dict) and "__class" in encoded:
            op_params[name] = cls._deserialize_param(encoded)
        else:
            # Old style params, convert it
            op_params[name] = Param(encoded)

    return ParamsDict(op_params)
167 changes: 109 additions & 58 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
@@ -735,10 +735,9 @@ def test_dag_params_roundtrip(self, val, expected_val):
dag = DAG(dag_id='simple_dag', params=val)
BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))

serialized_dag = SerializedDAG.to_dict(dag)
serialized_dag_json = SerializedDAG.to_json(dag)

# serialized dag dict must be json serializable
json.dumps(serialized_dag)
serialized_dag = json.loads(serialized_dag_json)

assert "params" in serialized_dag["dag"]

@@ -770,14 +769,37 @@ def __init__(self, path: str):
params={'path': S3Param('s3://my_bucket/my_path')},
)

with pytest.raises(SerializationError):
SerializedDAG.to_dict(dag)
@pytest.mark.parametrize(
    'param',
    [
        Param('my value', description='hello', schema={'type': 'string'}),
        Param('my value', description='hello'),
        Param(None, description=None),
    ],
)
def test_full_param_roundtrip(self, param):
    """
    Test that a Param's value, description and schema all survive a full
    serialize -> JSON -> validate -> deserialize round trip of the DAG.
    """
    dag = DAG(dag_id='simple_dag', params={'my_param': param})
    serialized_json = SerializedDAG.to_json(dag)
    serialized = json.loads(serialized_json)
    SerializedDAG.validate_schema(serialized)
    deserialized_dag = SerializedDAG.from_dict(serialized)

    # Item access on the params mapping is compared against the param's value ...
    assert deserialized_dag.params["my_param"] == param.value
    # ... so use dict.get to reach the stored Param object itself.
    observed_param = dict.get(deserialized_dag.params, 'my_param')
    assert isinstance(observed_param, Param)
    assert observed_param.description == param.description
    assert observed_param.schema == param.schema

@pytest.mark.parametrize(
"val, expected_val",
[
(None, {}),
({"param_1": "value_1"}, {"param_1": "value_1"}),
({"param_1": {1, 2, 3}}, {"param_1": {1, 2, 3}}),
],
)
def test_task_params_roundtrip(self, val, expected_val):
@@ -1073,56 +1095,56 @@ def test_no_new_fields_added_to_base_operator(self):
base_operator = BaseOperator(task_id="10")
fields = base_operator.__dict__
assert {
'_BaseOperator__instantiated': True,
'_dag': None,
'_downstream_task_ids': set(),
'_inlets': [],
'_log': base_operator.log,
'_outlets': [],
'_upstream_task_ids': set(),
'_pre_execute_hook': None,
'_post_execute_hook': None,
'depends_on_past': False,
'do_xcom_push': True,
'doc': None,
'doc_json': None,
'doc_md': None,
'doc_rst': None,
'doc_yaml': None,
'email': None,
'email_on_failure': True,
'email_on_retry': True,
'end_date': None,
'execution_timeout': None,
'executor_config': {},
'inlets': [],
'label': '10',
'max_active_tis_per_dag': None,
'max_retry_delay': None,
'on_execute_callback': None,
'on_failure_callback': None,
'on_retry_callback': None,
'on_success_callback': None,
'outlets': [],
'owner': 'airflow',
'params': {},
'pool': 'default_pool',
'pool_slots': 1,
'priority_weight': 1,
'queue': 'default',
'resources': None,
'retries': 0,
'retry_delay': timedelta(0, 300),
'retry_exponential_backoff': False,
'run_as_user': None,
'sla': None,
'start_date': None,
'subdag': None,
'task_id': '10',
'trigger_rule': 'all_success',
'wait_for_downstream': False,
'weight_rule': 'downstream',
} == fields, """
'_BaseOperator__instantiated': True,
'_dag': None,
'_downstream_task_ids': set(),
'_inlets': [],
'_log': base_operator.log,
'_outlets': [],
'_upstream_task_ids': set(),
'_pre_execute_hook': None,
'_post_execute_hook': None,
'depends_on_past': False,
'do_xcom_push': True,
'doc': None,
'doc_json': None,
'doc_md': None,
'doc_rst': None,
'doc_yaml': None,
'email': None,
'email_on_failure': True,
'email_on_retry': True,
'end_date': None,
'execution_timeout': None,
'executor_config': {},
'inlets': [],
'label': '10',
'max_active_tis_per_dag': None,
'max_retry_delay': None,
'on_execute_callback': None,
'on_failure_callback': None,
'on_retry_callback': None,
'on_success_callback': None,
'outlets': [],
'owner': 'airflow',
'params': {},
'pool': 'default_pool',
'pool_slots': 1,
'priority_weight': 1,
'queue': 'default',
'resources': None,
'retries': 0,
'retry_delay': timedelta(0, 300),
'retry_exponential_backoff': False,
'run_as_user': None,
'sla': None,
'start_date': None,
'subdag': None,
'task_id': '10',
'trigger_rule': 'all_success',
'wait_for_downstream': False,
'weight_rule': 'downstream',
} == fields, """
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
ACTION NEEDED! PLEASE READ THIS CAREFULLY AND CORRECT TESTS CAREFULLY
@@ -1439,11 +1461,12 @@ def test_serialized_objects_are_sorted(self, object_to_serialized, expected_outp
assert serialized_obj == expected_output

def test_params_upgrade(self):
"""when pre-2.2.0 param (i.e. primitive) is deserialized we convert to Param"""
serialized = {
"__version": 1,
"dag": {
"_dag_id": "simple_dag",
"fileloc": __file__,
"fileloc": '__file__',
"tasks": [],
"timezone": "UTC",
"params": {"none": None, "str": "str", "dict": {"a": "b"}},
@@ -1456,12 +1479,15 @@ def test_params_upgrade(self):
assert isinstance(dict.__getitem__(dag.params, "none"), Param)
assert dag.params["str"] == "str"

def test_params_serialize_default(self):
def test_params_serialize_default_2_2_0(self):
"""In 2.2.0, param ``default`` was assumed to be a json-serializable object and was not run through
the standard serializer function. In 2.2.2 we serialize param ``default``. We keep this
test only to ensure that params stored in 2.2.0 can still be parsed correctly."""
serialized = {
"__version": 1,
"dag": {
"_dag_id": "simple_dag",
"fileloc": __file__,
"fileloc": '/path/to/file.py',
"tasks": [],
"timezone": "UTC",
"params": {"str": {"__class": "airflow.models.param.Param", "default": "str"}},
@@ -1473,6 +1499,31 @@ def test_params_serialize_default(self):
assert isinstance(dict.__getitem__(dag.params, "str"), Param)
assert dag.params["str"] == "str"

def test_params_serialize_default(self):
    """Test that a param stored with serializer-encoded attributes is parsed correctly.

    Here ``schema`` is stored in the encoded ``__type`` / ``__var`` form while
    ``default`` and ``description`` are plain strings; the deserialized Param
    must expose the decoded value, description and schema.
    """
    serialized = {
        "__version": 1,
        "dag": {
            "_dag_id": "simple_dag",
            "fileloc": '/path/to/file.py',
            "tasks": [],
            "timezone": "UTC",
            "params": {
                "my_param": {
                    "default": "a string value",
                    "description": "hello",
                    "schema": {"__var": {"type": "string"}, "__type": "dict"},
                    "__class": "airflow.models.param.Param",
                }
            },
        },
    }
    SerializedDAG.validate_schema(serialized)
    dag = SerializedDAG.from_dict(serialized)

    assert dag.params["my_param"] == "a string value"
    # dict.get is used to reach the stored Param object rather than its value.
    param = dict.get(dag.params, 'my_param')
    assert isinstance(param, Param)
    assert param.description == 'hello'
    assert param.schema == {'type': 'string'}


def test_kubernetes_optional():
"""Serialisation / deserialisation continues to work without kubernetes installed"""

0 comments on commit 33ac787

Please sign in to comment.