Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix serialization of Params with set data type #19267

Merged
merged 15 commits into from
Nov 5, 2021
Merged
25 changes: 16 additions & 9 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r
elif isinstance(var, TaskGroup):
return SerializedTaskGroup.serialize_task_group(var)
elif isinstance(var, Param):
return cls._encode(var.dump(), type_=DAT.PARAM)
return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
dstandish marked this conversation as resolved.
Show resolved Hide resolved
else:
log.debug('Cast type %s to str in serialization.', type(var))
return str(var)
Expand Down Expand Up @@ -368,9 +368,7 @@ def _deserialize(cls, encoded_var: Any) -> Any:
elif type_ == DAT.TUPLE:
return tuple(cls._deserialize(v) for v in var)
elif type_ == DAT.PARAM:
param_class = import_string(var['_type'])
del var['_type']
return param_class(**var)
return cls._deserialize_param(var)
else:
raise TypeError(f'Invalid type {type_!s} in deserialization.')

Expand Down Expand Up @@ -409,16 +407,26 @@ def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) -
return True
return False

@classmethod
def _serialize_param(cls, param: Param):
    """Dump ``param`` to a dict and pass its ``'value'`` entry through ``_serialize``.

    :param param: the Param to serialize
    :return: the dict produced by ``param.dump()`` with ``'value'`` replaced
        by its serialized form
    """
    dumped = param.dump()
    raw_value = dumped['value']
    # The value may be any object (e.g. a set), so it is encoded with the
    # regular serializer rather than stored as-is.
    dumped['value'] = cls._serialize(raw_value)
    return dumped

@classmethod
def _deserialize_param(cls, param: Dict):
    """Rebuild a param object from its serialized dict form.

    The class to instantiate is resolved with ``import_string`` from the
    ``'__class'`` entry. The ``'value'`` entry is decoded with
    ``_deserialize`` and passed to the constructor as ``default``, together
    with the ``'description'`` entry.

    The input dict is not modified: the decoded value is kept in a local
    variable so the same encoded dict can be deserialized more than once
    without feeding an already-decoded value back into ``_deserialize``.

    :param param: serialized param dict containing ``'__class'``,
        ``'value'`` and ``'description'`` keys
    :return: an instance of the class named by ``param['__class']``
    :raises KeyError: if one of the required keys is missing
    """
    param_class = import_string(param['__class'])
    default = cls._deserialize(param['value'])
    return param_class(default=default, description=param['description'])
uranusjr marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def _serialize_params_dict(cls, params: ParamsDict):
    """Serialize Params dict for a DAG/Task

    Each value is serialized with ``_serialize_param`` and stored under its
    original key.

    :param params: mapping of param name to param object
    :return: dict mapping each param name to its serialized dict
    :raises ValueError: if any value is not exactly of type
        ``airflow.models.param.Param``
    """
    serialized_params = {}
    for k, v in params.items():
        # TODO: As of now, we would allow serialization of params which are of type Param only
        if f'{v.__module__}.{v.__class__.__name__}' != 'airflow.models.param.Param':
            raise ValueError('Params to a DAG or a Task can be only of type airflow.models.param.Param')
        # The earlier hand-built ``kwargs`` dict was overwritten immediately
        # by this assignment, so only the shared helper is used.
        serialized_params[k] = cls._serialize_param(v)
    return serialized_params
Expand All @@ -429,8 +437,7 @@ def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict:
op_params = {}
for k, v in encoded_params.items():
if isinstance(v, dict) and "__class" in v:
param_class = import_string(v['__class'])
op_params[k] = param_class(**v)
op_params[k] = cls._deserialize_param(v)
else:
# Old style params, upgrade it
op_params[k] = Param(v)
Expand Down