Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix serialization of Params with set data type #19267

Merged
merged 15 commits into from
Nov 5, 2021
Merged
3 changes: 2 additions & 1 deletion airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Param:
"""

__NO_VALUE_SENTINEL = NoValueSentinel()
CLASS_IDENTIFIER = '__class'
uranusjr marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, default: Any = __NO_VALUE_SENTINEL, description: str = None, **kwargs):
self.value = default
Expand Down Expand Up @@ -90,7 +91,7 @@ def resolve(self, value: Optional[Any] = __NO_VALUE_SENTINEL, suppress_exception

def dump(self) -> dict:
"""Dump the Param as a dictionary"""
out_dict = {'__class': f'{self.__module__}.{self.__class__.__name__}'}
out_dict = {self.CLASS_IDENTIFIER: f'{self.__module__}.{self.__class__.__name__}'}
out_dict.update(self.__dict__)
return out_dict

Expand Down
29 changes: 20 additions & 9 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r
elif isinstance(var, TaskGroup):
return SerializedTaskGroup.serialize_task_group(var)
elif isinstance(var, Param):
return cls._encode(var.dump(), type_=DAT.PARAM)
return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
dstandish marked this conversation as resolved.
Show resolved Hide resolved
else:
log.debug('Cast type %s to str in serialization.', type(var))
return str(var)
Expand Down Expand Up @@ -368,9 +368,7 @@ def _deserialize(cls, encoded_var: Any) -> Any:
elif type_ == DAT.TUPLE:
return tuple(cls._deserialize(v) for v in var)
elif type_ == DAT.PARAM:
param_class = import_string(var['_type'])
del var['_type']
return param_class(**var)
return cls._deserialize_param(var)
else:
raise TypeError(f'Invalid type {type_!s} in deserialization.')

Expand Down Expand Up @@ -409,16 +407,30 @@ def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) -
return True
return False

@classmethod
def _serialize_param(cls, param: Param):
param_dict = param.dump()
param_class_name = param_dict.pop(Param.CLASS_IDENTIFIER)
param_kwargs = dict(default=param_dict.pop('value'), **param_dict)
return {Param.CLASS_IDENTIFIER: param_class_name, **cls._serialize(param_kwargs)}

@classmethod
def _deserialize_param(cls, param_dict: Dict):
param_class = import_string(param_dict.pop(Param.CLASS_IDENTIFIER))
try:
param_kwargs = cls._deserialize(param_dict)
except KeyError:
param_kwargs = param_dict
dstandish marked this conversation as resolved.
Show resolved Hide resolved
return param_class(**param_kwargs)

@classmethod
def _serialize_params_dict(cls, params: ParamsDict):
"""Serialize Params dict for a DAG/Task"""
serialized_params = {}
for k, v in params.items():
# TODO: As of now, we would allow serialization of params which are of type Param only
if f'{v.__module__}.{v.__class__.__name__}' == 'airflow.models.param.Param':
kwargs = v.dump()
kwargs['default'] = kwargs.pop('value')
serialized_params[k] = kwargs
serialized_params[k] = cls._serialize_param(v)
else:
raise ValueError('Params to a DAG or a Task can be only of type airflow.models.param.Param')
return serialized_params
Expand All @@ -429,8 +441,7 @@ def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict:
op_params = {}
for k, v in encoded_params.items():
if isinstance(v, dict) and "__class" in v:
param_class = import_string(v['__class'])
op_params[k] = param_class(**v)
op_params[k] = cls._deserialize_param(v)
else:
# Old style params, upgrade it
op_params[k] = Param(v)
Expand Down
6 changes: 6 additions & 0 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import copy
import importlib
import importlib.util
import json
import multiprocessing
import os
from datetime import datetime, timedelta
Expand Down Expand Up @@ -724,6 +725,7 @@ def test_roundtrip_relativedelta(self, val, expected):
[
(None, {}),
({"param_1": "value_1"}, {"param_1": "value_1"}),
({"param_1": {1, 2, 3}}, {"param_1": {1, 2, 3}}),
],
)
def test_dag_params_roundtrip(self, val, expected_val):
Expand All @@ -734,6 +736,10 @@ def test_dag_params_roundtrip(self, val, expected_val):
BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))

serialized_dag = SerializedDAG.to_dict(dag)

# serialized dag dict must be json serializable
json.dumps(serialized_dag)

assert "params" in serialized_dag["dag"]

deserialized_dag = SerializedDAG.from_dict(serialized_dag)
Expand Down