From 33ac787f89943fe7b2b19c566abdaa3b01758256 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 29 Oct 2021 07:43:49 -0700 Subject: [PATCH] be more explicit about ser and deser; add tests --- airflow/serialization/serialized_objects.py | 36 ++-- tests/serialization/test_dag_serialization.py | 167 ++++++++++++------ 2 files changed, 132 insertions(+), 71 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index ebc473390b08a..f973d56529ac6 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -55,7 +55,6 @@ except ImportError: HAS_KUBERNETES = False - if TYPE_CHECKING: from airflow.ti_deps.deps.base_ti_dep import BaseTIDep @@ -409,19 +408,30 @@ def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) - @classmethod def _serialize_param(cls, param: Param): - param_dict = param.dump() - param_class_name = param_dict.pop(Param.CLASS_IDENTIFIER) - param_kwargs = dict(default=param_dict.pop('value'), **param_dict) - return {Param.CLASS_IDENTIFIER: param_class_name, **cls._serialize(param_kwargs)} + return dict( + __class=f"{param.__module__}.{param.__class__.__name__}", + default=cls._serialize(param.value), + description=cls._serialize(param.description), + schema=cls._serialize(param.schema), + ) @classmethod def _deserialize_param(cls, param_dict: Dict): - param_class = import_string(param_dict.pop(Param.CLASS_IDENTIFIER)) - try: - param_kwargs = cls._deserialize(param_dict) - except KeyError: - param_kwargs = param_dict - return param_class(**param_kwargs) + class_name = param_dict['__class'] + class_ = import_string(class_name) # type: Type[Param] + attrs = ('default', 'description', 'schema') + kwargs = {} + for attr in attrs: + if attr not in param_dict: + continue + val = param_dict[attr] + is_serialized = isinstance(val, dict) and '__type' in val + if is_serialized: + 
deserialized_val = cls._deserialize(param_dict[attr]) + kwargs[attr] = deserialized_val + else: + kwargs[attr] = val + return class_(**kwargs) @classmethod def _serialize_params_dict(cls, params: ParamsDict): @@ -437,13 +447,13 @@ def _serialize_params_dict(cls, params: ParamsDict): @classmethod def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict: - """Deserialize a DAGs Params dict""" + """Deserialize a DAG's Params dict""" op_params = {} for k, v in encoded_params.items(): if isinstance(v, dict) and "__class" in v: op_params[k] = cls._deserialize_param(v) else: - # Old style params, upgrade it + # Old style params, convert it op_params[k] = Param(v) return ParamsDict(op_params) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 2ef092dc1b3aa..99e208db48996 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -735,10 +735,9 @@ def test_dag_params_roundtrip(self, val, expected_val): dag = DAG(dag_id='simple_dag', params=val) BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1)) - serialized_dag = SerializedDAG.to_dict(dag) + serialized_dag_json = SerializedDAG.to_json(dag) - # serialized dag dict must be json serializable - json.dumps(serialized_dag) + serialized_dag = json.loads(serialized_dag_json) assert "params" in serialized_dag["dag"] @@ -770,14 +769,37 @@ def __init__(self, path: str): params={'path': S3Param('s3://my_bucket/my_path')}, ) - with pytest.raises(SerializationError): - SerializedDAG.to_dict(dag) + @pytest.mark.parametrize( + 'param', + [ + Param('my value', description='hello', schema={'type': 'string'}), + Param('my value', description='hello'), + Param(None, description=None), + ] + ) + def test_full_param_roundtrip(self, param): + """ + Test that a Param's value, description and schema survive a JSON serialization roundtrip + """ + + dag = DAG(dag_id='simple_dag', 
params={'my_param':param}) + serialized_json = SerializedDAG.to_json(dag) + serialized = json.loads(serialized_json) + SerializedDAG.validate_schema(serialized) + dag = SerializedDAG.from_dict(serialized) + + assert dag.params["my_param"] == param.value + observed_param = dict.get(dag.params, 'my_param') + assert isinstance(observed_param, Param) + assert observed_param.description == param.description + assert observed_param.schema == param.schema @pytest.mark.parametrize( "val, expected_val", [ (None, {}), ({"param_1": "value_1"}, {"param_1": "value_1"}), + ({"param_1": {1, 2, 3}}, {"param_1": {1, 2, 3}}), ], ) def test_task_params_roundtrip(self, val, expected_val): @@ -1073,56 +1095,56 @@ def test_no_new_fields_added_to_base_operator(self): base_operator = BaseOperator(task_id="10") fields = base_operator.__dict__ assert { - '_BaseOperator__instantiated': True, - '_dag': None, - '_downstream_task_ids': set(), - '_inlets': [], - '_log': base_operator.log, - '_outlets': [], - '_upstream_task_ids': set(), - '_pre_execute_hook': None, - '_post_execute_hook': None, - 'depends_on_past': False, - 'do_xcom_push': True, - 'doc': None, - 'doc_json': None, - 'doc_md': None, - 'doc_rst': None, - 'doc_yaml': None, - 'email': None, - 'email_on_failure': True, - 'email_on_retry': True, - 'end_date': None, - 'execution_timeout': None, - 'executor_config': {}, - 'inlets': [], - 'label': '10', - 'max_active_tis_per_dag': None, - 'max_retry_delay': None, - 'on_execute_callback': None, - 'on_failure_callback': None, - 'on_retry_callback': None, - 'on_success_callback': None, - 'outlets': [], - 'owner': 'airflow', - 'params': {}, - 'pool': 'default_pool', - 'pool_slots': 1, - 'priority_weight': 1, - 'queue': 'default', - 'resources': None, - 'retries': 0, - 'retry_delay': timedelta(0, 300), - 'retry_exponential_backoff': False, - 'run_as_user': None, - 'sla': None, - 'start_date': None, - 'subdag': None, - 'task_id': '10', - 'trigger_rule': 'all_success', - 'wait_for_downstream': 
False, - 'weight_rule': 'downstream', - } == fields, """ + '_BaseOperator__instantiated': True, + '_dag': None, + '_downstream_task_ids': set(), + '_inlets': [], + '_log': base_operator.log, + '_outlets': [], + '_upstream_task_ids': set(), + '_pre_execute_hook': None, + '_post_execute_hook': None, + 'depends_on_past': False, + 'do_xcom_push': True, + 'doc': None, + 'doc_json': None, + 'doc_md': None, + 'doc_rst': None, + 'doc_yaml': None, + 'email': None, + 'email_on_failure': True, + 'email_on_retry': True, + 'end_date': None, + 'execution_timeout': None, + 'executor_config': {}, + 'inlets': [], + 'label': '10', + 'max_active_tis_per_dag': None, + 'max_retry_delay': None, + 'on_execute_callback': None, + 'on_failure_callback': None, + 'on_retry_callback': None, + 'on_success_callback': None, + 'outlets': [], + 'owner': 'airflow', + 'params': {}, + 'pool': 'default_pool', + 'pool_slots': 1, + 'priority_weight': 1, + 'queue': 'default', + 'resources': None, + 'retries': 0, + 'retry_delay': timedelta(0, 300), + 'retry_exponential_backoff': False, + 'run_as_user': None, + 'sla': None, + 'start_date': None, + 'subdag': None, + 'task_id': '10', + 'trigger_rule': 'all_success', + 'wait_for_downstream': False, + 'weight_rule': 'downstream', + } == fields, """ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ACTION NEEDED! PLEASE READ THIS CAREFULLY AND CORRECT TESTS CAREFULLY @@ -1439,11 +1461,12 @@ def test_serialized_objects_are_sorted(self, object_to_serialized, expected_outp assert serialized_obj == expected_output def test_params_upgrade(self): + """when pre-2.2.0 param (i.e. 
primitive) is deserialized we convert to Param""" serialized = { "__version": 1, "dag": { "_dag_id": "simple_dag", - "fileloc": __file__, + "fileloc": '__file__', "tasks": [], "timezone": "UTC", "params": {"none": None, "str": "str", "dict": {"a": "b"}}, @@ -1456,12 +1479,15 @@ def test_params_upgrade(self): assert isinstance(dict.__getitem__(dag.params, "none"), Param) assert dag.params["str"] == "str" - def test_params_serialize_default(self): + def test_params_serialize_default_2_2_0(self): + """In 2.2.0, param ``default`` was assumed to be json-serializable and was not run through + the standard serializer function. In 2.2.2 we serialize param ``default``. We keep this + test only to ensure that params stored in 2.2.0 can still be parsed correctly.""" serialized = { "__version": 1, "dag": { "_dag_id": "simple_dag", - "fileloc": __file__, + "fileloc": '/path/to/file.py', "tasks": [], "timezone": "UTC", "params": {"str": {"__class": "airflow.models.param.Param", "default": "str"}}, @@ -1473,6 +1499,31 @@ def test_params_serialize_default(self): assert isinstance(dict.__getitem__(dag.params, "str"), Param) assert dag.params["str"] == "str" + def test_params_serialize_default(self): + """Since 2.2.2, param ``default``, ``description`` and ``schema`` are run through the standard + serializer function. 
This test ensures that params stored in that format (note the encoded + ``schema`` dict) are deserialized into an equivalent ``Param``.""" + serialized = { + "__version": 1, + "dag": { + "_dag_id": "simple_dag", + "fileloc": '/path/to/file.py', + "tasks": [], + "timezone": "UTC", + "params": {"my_param": {"default": "a string value", "description": "hello", + "schema": {"__var": {"type": "string"}, "__type": "dict"}, + "__class": "airflow.models.param.Param"}}, + }, + } + SerializedDAG.validate_schema(serialized) + dag = SerializedDAG.from_dict(serialized) + + assert dag.params["my_param"] == "a string value" + param = dict.get(dag.params, 'my_param') + assert isinstance(param, Param) + assert param.description == 'hello' + assert param.schema == {'type': 'string'} + def test_kubernetes_optional(): """Serialisation / deserialisation continues to work without kubernetes installed"""