From 6af245756f6c065cb3a509a51427158422ce72e0 Mon Sep 17 00:00:00 2001 From: Philippe Moussalli Date: Mon, 26 Feb 2024 19:04:15 +0100 Subject: [PATCH 1/2] add struct to types --- src/fondant/core/schema.py | 59 +++++++++++++++++++++++++++++++++++--- tests/core/test_schema.py | 34 ++++++++++++++++++++++ 2 files changed, 89 insertions(+), 4 deletions(-) diff --git a/src/fondant/core/schema.py b/src/fondant/core/schema.py index 953cd916..7108fa4c 100644 --- a/src/fondant/core/schema.py +++ b/src/fondant/core/schema.py @@ -78,6 +78,7 @@ def get_path(self): "date64": pa.date64(), "duration": pa.duration("us"), "string": pa.string(), + "struct": pa.struct([]), "utf8": pa.utf8(), "binary": pa.binary(), "large_binary": pa.large_binary(), @@ -135,6 +136,31 @@ def list(cls, data_type: t.Union[str, pa.DataType, "Type"]) -> "Type": pa.list_(data_type.value if isinstance(data_type, Type) else data_type), ) + @classmethod + def struct( + cls, + fields: t.List[t.Tuple[str, t.Union[str, pa.DataType, "Type"]]], + ) -> "Type": + """ + Creates a new `Type` instance representing a struct with the specified fields. + + Args: + fields: A list of tuples where each tuple contains the name and type of a field. + + Returns: + A new `Type` instance representing a struct with the specified fields. + """ + validated_fields = [] + for name, data_type in fields: + if isinstance(data_type, Type): + type_ = data_type.value + elif isinstance(data_type, pa.DataType): + type_ = data_type + else: + type_ = cls._validate_data_type(data_type) + validated_fields.append(pa.field(name, type_)) + return cls(pa.struct(validated_fields)) + @classmethod def from_dict(cls, json_schema: dict): """ @@ -147,13 +173,35 @@ def from_dict(cls, json_schema: dict): Returns: A new `Type` instance representing the specified data type. """ - if json_schema["type"] == "array": - items = json_schema["items"] + type_name = json_schema.get("type") + + if type_name is None: + msg = "Invalid or missing 'type' key in the schema." + raise InvalidTypeSchema(msg) + + if type_name == "array": + items = json_schema.get("items") if isinstance(items, dict): return cls.list(cls.from_dict(items)) - return None + if isinstance(items, str): + return cls.list(items) + + msg = "Invalid 'items' type in array schema." + raise InvalidTypeSchema(msg) - return cls(json_schema["type"]) + if type_name == "object": + properties = json_schema.get("properties") + if not isinstance(properties, dict): + msg = "Invalid 'properties' type in object schema." + raise InvalidTypeSchema(msg) + fields = [(name, cls.from_dict(prop)) for name, prop in properties.items()] + return cls.struct(fields) + + if isinstance(type_name, str): + return cls(type_name) + + msg = f"Invalid 'type' value: {type_name}" + raise InvalidTypeSchema(msg) def to_dict(self) -> dict: """ @@ -166,6 +214,9 @@ def to_dict(self) -> dict: items = self.value.value_type if isinstance(items, pa.DataType): return {"type": "array", "items": Type(items).to_dict()} + elif isinstance(self.value, pa.StructType): + fields = [(field.name, Type(field.type).to_dict()) for field in self.value] + return {"type": "object", "properties": dict(fields)} type_ = None for type_name, data_type in _TYPES.items(): diff --git a/tests/core/test_schema.py b/tests/core/test_schema.py index f442835c..c1d9df34 100644 --- a/tests/core/test_schema.py +++ b/tests/core/test_schema.py @@ -15,6 +15,22 @@ def test_valid_type(): "type": "array", "items": {"type": "float32"}, } + assert Type.struct( + [ + ("f1", "int32"), + ("f2", "string"), + ("f3", Type.list("int8")), + ("f4", Type.struct([("f5", "int32")])), + ], + ).to_dict() == { + "type": "object", + "properties": { + "f1": {"type": "int32"}, + "f2": {"type": "string"}, + "f3": {"type": "array", "items": {"type": "int8"}}, + "f4": {"type": "object", "properties": {"f5": {"type": "int32"}}}, + }, + } def test_valid_json_schema(): @@ -26,6 +42,24 @@ def test_valid_json_schema(): assert Type.from_dict( {"type": "array", "items": {"type": "array", "items": {"type": "int8"}}}, ).value == pa.list_(pa.list_(pa.int8())) + assert Type.from_dict( + { + "type": "object", + "properties": { + "f1": {"type": "int32"}, + "f2": {"type": "string"}, + "f3": {"type": "array", "items": {"type": "int8"}}, + "f4": {"type": "object", "properties": {"f5": {"type": "int32"}}}, + }, + }, + ) == Type.struct( + [ + ("f1", "int32"), + ("f2", "string"), + ("f3", Type.list("int8")), + ("f4", Type.struct([("f5", "int32")])), + ], + ) @pytest.mark.parametrize( From e99c580fb265ae43afc05cf2f020ea21575d735c Mon Sep 17 00:00:00 2001 From: Philippe Moussalli Date: Tue, 27 Feb 2024 15:58:03 +0100 Subject: [PATCH 2/2] Add struct/object to component schema --- src/fondant/core/schemas/common.json | 15 +++++++++++++-- .../examples/component_specs/component.yaml | 19 +++++++++++++++++++ tests/component/test_component.py | 17 ++++++++++++++++- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/fondant/core/schemas/common.json b/src/fondant/core/schemas/common.json index a88e889e..b00eb8a0 100644 --- a/src/fondant/core/schemas/common.json +++ b/src/fondant/core/schemas/common.json @@ -27,7 +27,8 @@ "binary", "list", "struct", - "array" + "array", + "object" ] }, "field": { @@ -37,6 +38,16 @@ "type": "string", "$ref": "#/definitions/subset_data_type" }, + "properties": { + "type": "object", + "properties": { + "type": { + "type": "string", + "$ref": "#/definitions/subset_data_type" + } + }, + "additionalProperties": true + }, "items": { "oneOf": [ { @@ -76,4 +87,4 @@ } } } -} \ No newline at end of file +} diff --git a/tests/component/examples/component_specs/component.yaml b/tests/component/examples/component_specs/component.yaml index d1f28b76..43d1f221 100644 --- a/tests/component/examples/component_specs/component.yaml +++ b/tests/component/examples/component_specs/component.yaml @@ -12,6 +12,25 @@ produces: items: type: float32 + element: + type: object + properties: + id: + type: string + number: + type: int32 + + elements: + type: array + items: + type: object + properties: + id: + type: string + number: + type: int32 + + args: flag: description: user argument diff --git a/tests/component/test_component.py b/tests/component/test_component.py index 75d4dbe6..75aa7662 100644 --- a/tests/component/test_component.py +++ b/tests/component/test_component.py @@ -52,7 +52,22 @@ def _patched_data_loading(monkeypatch): """Mock data loading so no actual data is loaded.""" def mocked_load_dataframe(self): - return dd.from_dict({"images_data": [1, 2, 3]}, npartitions=N_PARTITIONS) + return dd.from_dict( + { + "images_data": [1, 2, 3], + "element": [ + ("1", 1), + ("2", 2), + ("3", 3), + ], + "elements": [ + [("1", 1), ("2", 2), ("3", 3)], + [("4", 4), ("5", 5), ("6", 6)], + [("7", 7), ("8", 8), ("9", 9)], + ], + }, + npartitions=N_PARTITIONS, + ) monkeypatch.setattr(DaskDataLoader, "load_dataframe", mocked_load_dataframe)