Skip to content

Commit

Permalink
Fix serialization of Params with set data type (#19267)
Browse files Browse the repository at this point in the history
This is a solution for #19096

Previously, the serialization of params did not run the param value through the `_serialize` function, resulting in non-json-serializable dictionaries.  This manifested when a user, for example, tried to use params with a default value of type `set`.

Here we change the logic to run the param value through the serialization process.  And I add a test for the `set` case.

closes #19096

(cherry picked from commit 8512e05)
  • Loading branch information
dstandish authored and jedcunningham committed Nov 5, 2021
1 parent 75f1d2a commit 157a864
Showing 4 changed files with 125 additions and 23 deletions.
4 changes: 2 additions & 2 deletions airflow/models/param.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, Optional

import jsonschema
@@ -49,6 +48,7 @@ class Param:
"""

__NO_VALUE_SENTINEL = NoValueSentinel()
CLASS_IDENTIFIER = '__class'

def __init__(self, default: Any = __NO_VALUE_SENTINEL, description: str = None, **kwargs):
self.value = default
@@ -90,7 +90,7 @@ def resolve(self, value: Optional[Any] = __NO_VALUE_SENTINEL, suppress_exception

def dump(self) -> dict:
"""Dump the Param as a dictionary"""
out_dict = {'__class': f'{self.__module__}.{self.__class__.__name__}'}
out_dict = {self.CLASS_IDENTIFIER: f'{self.__module__}.{self.__class__.__name__}'}
out_dict.update(self.__dict__)
return out_dict

22 changes: 20 additions & 2 deletions airflow/serialization/schema.json
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@
"dag": {
"type": "object",
"properties": {
"params": { "$ref": "#/definitions/dict" },
"params": { "$ref": "#/definitions/params_dict" },
"_dag_id": { "type": "string" },
"tasks": { "$ref": "#/definitions/tasks" },
"timezone": { "$ref": "#/definitions/timezone" },
@@ -135,6 +135,24 @@
"type": "array",
"additionalProperties": { "$ref": "#/definitions/operator" }
},
"params_dict": {
"type": "object",
"additionalProperties": {"$ref": "#/definitions/param" }
},
"param": {
"$comment": "A param for a dag / operator",
"type": "object",
"required": [
"__class",
"default"
],
"properties": {
"__class": { "type": "string" },
"default": {},
"description": {"anyOf": [{"type":"string"}, {"type":"null"}]},
"schema": { "$ref": "#/definitions/dict" }
}
},
"operator": {
"$comment": "A task/operator in a DAG",
"type": "object",
@@ -166,7 +184,7 @@
"retry_delay": { "$ref": "#/definitions/timedelta" },
"retry_exponential_backoff": { "type": "boolean" },
"max_retry_delay": { "$ref": "#/definitions/timedelta" },
"params": { "$ref": "#/definitions/dict" },
"params": { "$ref": "#/definitions/params_dict" },
"priority_weight": { "type": "number" },
"weight_rule": { "type": "string" },
"executor_config": { "$ref": "#/definitions/dict" },
50 changes: 38 additions & 12 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
@@ -55,7 +55,6 @@
except ImportError:
HAS_KUBERNETES = False


if TYPE_CHECKING:
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep

@@ -325,7 +324,7 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r
elif isinstance(var, TaskGroup):
return SerializedTaskGroup.serialize_task_group(var)
elif isinstance(var, Param):
return cls._encode(var.dump(), type_=DAT.PARAM)
return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
else:
log.debug('Cast type %s to str in serialization.', type(var))
return str(var)
@@ -368,9 +367,7 @@ def _deserialize(cls, encoded_var: Any) -> Any:
elif type_ == DAT.TUPLE:
return tuple(cls._deserialize(v) for v in var)
elif type_ == DAT.PARAM:
param_class = import_string(var['_type'])
del var['_type']
return param_class(**var)
return cls._deserialize_param(var)
else:
raise TypeError(f'Invalid type {type_!s} in deserialization.')

@@ -409,30 +406,59 @@ def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) -
return True
return False

@classmethod
def _serialize_param(cls, param: Param):
return dict(
__class=f"{param.__module__}.{param.__class__.__name__}",
default=cls._serialize(param.value),
description=cls._serialize(param.description),
schema=cls._serialize(param.schema),
)

@classmethod
def _deserialize_param(cls, param_dict: Dict):
"""
In 2.2.0, Param attrs were assumed to be json-serializable and were not run through
this class's ``_serialize`` method. So before running through ``_deserialize``,
we first verify that it's necessary to do.
"""
class_name = param_dict['__class']
class_ = import_string(class_name) # type: Type[Param]
attrs = ('default', 'description', 'schema')
kwargs = {}
for attr in attrs:
if attr not in param_dict:
continue
val = param_dict[attr]
is_serialized = isinstance(val, dict) and '__type' in val
if is_serialized:
deserialized_val = cls._deserialize(param_dict[attr])
kwargs[attr] = deserialized_val
else:
kwargs[attr] = val
return class_(**kwargs)

@classmethod
def _serialize_params_dict(cls, params: ParamsDict):
"""Serialize Params dict for a DAG/Task"""
serialized_params = {}
for k, v in params.items():
# TODO: As of now, we would allow serialization of params which are of type Param only
if f'{v.__module__}.{v.__class__.__name__}' == 'airflow.models.param.Param':
kwargs = v.dump()
kwargs['default'] = kwargs.pop('value')
serialized_params[k] = kwargs
serialized_params[k] = cls._serialize_param(v)
else:
raise ValueError('Params to a DAG or a Task can be only of type airflow.models.param.Param')
return serialized_params

@classmethod
def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict:
"""Deserialize a DAGs Params dict"""
"""Deserialize a DAG's Params dict"""
op_params = {}
for k, v in encoded_params.items():
if isinstance(v, dict) and "__class" in v:
param_class = import_string(v['__class'])
op_params[k] = param_class(**v)
op_params[k] = cls._deserialize_param(v)
else:
# Old style params, upgrade it
# Old style params, convert it
op_params[k] = Param(v)

return ParamsDict(op_params)
72 changes: 65 additions & 7 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
import copy
import importlib
import importlib.util
import json
import multiprocessing
import os
from datetime import datetime, timedelta
@@ -724,6 +725,7 @@ def test_roundtrip_relativedelta(self, val, expected):
[
(None, {}),
({"param_1": "value_1"}, {"param_1": "value_1"}),
({"param_1": {1, 2, 3}}, {"param_1": {1, 2, 3}}),
],
)
def test_dag_params_roundtrip(self, val, expected_val):
@@ -733,7 +735,10 @@ def test_dag_params_roundtrip(self, val, expected_val):
dag = DAG(dag_id='simple_dag', params=val)
BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))

serialized_dag = SerializedDAG.to_dict(dag)
serialized_dag_json = SerializedDAG.to_json(dag)

serialized_dag = json.loads(serialized_dag_json)

assert "params" in serialized_dag["dag"]

deserialized_dag = SerializedDAG.from_dict(serialized_dag)
@@ -764,14 +769,37 @@ def __init__(self, path: str):
params={'path': S3Param('s3://my_bucket/my_path')},
)

with pytest.raises(SerializationError):
SerializedDAG.to_dict(dag)
@pytest.mark.parametrize(
'param',
[
Param('my value', description='hello', schema={'type': 'string'}),
Param('my value', description='hello'),
Param(None, description=None),
],
)
def test_full_param_roundtrip(self, param):
"""
Test to make sure that only native Param objects are being passed as dag or task params
"""

dag = DAG(dag_id='simple_dag', params={'my_param': param})
serialized_json = SerializedDAG.to_json(dag)
serialized = json.loads(serialized_json)
SerializedDAG.validate_schema(serialized)
dag = SerializedDAG.from_dict(serialized)

assert dag.params["my_param"] == param.value
observed_param = dict.get(dag.params, 'my_param')
assert isinstance(observed_param, Param)
assert observed_param.description == param.description
assert observed_param.schema == param.schema

@pytest.mark.parametrize(
"val, expected_val",
[
(None, {}),
({"param_1": "value_1"}, {"param_1": "value_1"}),
({"param_1": {1, 2, 3}}, {"param_1": {1, 2, 3}}),
],
)
def test_task_params_roundtrip(self, val, expected_val):
@@ -1433,29 +1461,32 @@ def test_serialized_objects_are_sorted(self, object_to_serialized, expected_outp
assert serialized_obj == expected_output

def test_params_upgrade(self):
"""when pre-2.2.0 param (i.e. primitive) is deserialized we convert to Param"""
serialized = {
"__version": 1,
"dag": {
"_dag_id": "simple_dag",
"fileloc": __file__,
"fileloc": '/path/to/file.py',
"tasks": [],
"timezone": "UTC",
"params": {"none": None, "str": "str", "dict": {"a": "b"}},
},
}
SerializedDAG.validate_schema(serialized)
dag = SerializedDAG.from_dict(serialized)

assert dag.params["none"] is None
assert isinstance(dict.__getitem__(dag.params, "none"), Param)
assert dag.params["str"] == "str"

def test_params_serialize_default(self):
def test_params_serialize_default_2_2_0(self):
"""In 2.0.0, param ``default`` was assumed to be json-serializable objects and were not run though
the standard serializer function. In 2.2.2 we serialize param ``default``. We keep this
test only to ensure that params stored in 2.2.0 can still be parsed correctly."""
serialized = {
"__version": 1,
"dag": {
"_dag_id": "simple_dag",
"fileloc": __file__,
"fileloc": '/path/to/file.py',
"tasks": [],
"timezone": "UTC",
"params": {"str": {"__class": "airflow.models.param.Param", "default": "str"}},
@@ -1467,6 +1498,33 @@ def test_params_serialize_default(self):
assert isinstance(dict.__getitem__(dag.params, "str"), Param)
assert dag.params["str"] == "str"

def test_params_serialize_default(self):
serialized = {
"__version": 1,
"dag": {
"_dag_id": "simple_dag",
"fileloc": '/path/to/file.py',
"tasks": [],
"timezone": "UTC",
"params": {
"my_param": {
"default": "a string value",
"description": "hello",
"schema": {"__var": {"type": "string"}, "__type": "dict"},
"__class": "airflow.models.param.Param",
}
},
},
}
SerializedDAG.validate_schema(serialized)
dag = SerializedDAG.from_dict(serialized)

assert dag.params["my_param"] == "a string value"
param = dict.get(dag.params, 'my_param')
assert isinstance(param, Param)
assert param.description == 'hello'
assert param.schema == {'type': 'string'}


def test_kubernetes_optional():
"""Serialisation / deserialisation continues to work without kubernetes installed"""

0 comments on commit 157a864

Please sign in to comment.