Skip to content

Commit

Permalink
Add struct to types (#879)
Browse files Browse the repository at this point in the history
This PR adds `Struct` to the list of valid types for Dask dataframes. It is
required for a client use case that ingests data in a specific format
(documents with nested named attributes).
  • Loading branch information
PhilippeMoussalli authored Feb 28, 2024
1 parent 88c6ea8 commit 6640d88
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 7 deletions.
59 changes: 55 additions & 4 deletions src/fondant/core/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def get_path(self):
"date64": pa.date64(),
"duration": pa.duration("us"),
"string": pa.string(),
"struct": pa.struct([]),
"utf8": pa.utf8(),
"binary": pa.binary(),
"large_binary": pa.large_binary(),
Expand Down Expand Up @@ -135,6 +136,31 @@ def list(cls, data_type: t.Union[str, pa.DataType, "Type"]) -> "Type":
pa.list_(data_type.value if isinstance(data_type, Type) else data_type),
)

@classmethod
def struct(
    cls,
    fields: t.List[t.Tuple[str, t.Union[str, pa.DataType, "Type"]]],
) -> "Type":
    """
    Build a `Type` that wraps a pyarrow struct made of the given named fields.

    Args:
        fields: `(name, data_type)` pairs, one per struct member, in the order
            they should appear in the struct. Each `data_type` may be another
            `Type`, a pyarrow `DataType`, or any other value accepted by
            `cls._validate_data_type`.

    Returns:
        A new `Type` whose value is the pyarrow struct built from `fields`.
    """

    def _to_arrow(spec):
        # Unwrap an existing `Type` to its underlying value.
        if isinstance(spec, Type):
            return spec.value
        # pyarrow data types are used as-is.
        if isinstance(spec, pa.DataType):
            return spec
        # Everything else is delegated to the class-level validator.
        return cls._validate_data_type(spec)

    arrow_fields = [
        pa.field(field_name, _to_arrow(spec)) for field_name, spec in fields
    ]
    return cls(pa.struct(arrow_fields))

@classmethod
def from_dict(cls, json_schema: dict):
"""
Expand All @@ -147,13 +173,35 @@ def from_dict(cls, json_schema: dict):
Returns:
A new `Type` instance representing the specified data type.
"""
if json_schema["type"] == "array":
items = json_schema["items"]
type_name = json_schema.get("type")

if type_name is None:
msg = "Invalid or missing 'type' key in the schema."
raise InvalidTypeSchema(msg)

if type_name == "array":
items = json_schema.get("items")
if isinstance(items, dict):
return cls.list(cls.from_dict(items))
return None
if isinstance(items, str):
return cls.list(items)

msg = "Invalid 'items' type in array schema."
raise InvalidTypeSchema(msg)

return cls(json_schema["type"])
if type_name == "object":
properties = json_schema.get("properties")
if not isinstance(properties, dict):
msg = "Invalid 'properties' type in object schema."
raise InvalidTypeSchema(msg)
fields = [(name, cls.from_dict(prop)) for name, prop in properties.items()]
return cls.struct(fields)

if isinstance(type_name, str):
return cls(type_name)

msg = f"Invalid 'type' value: {type_name}"
raise InvalidTypeSchema(msg)

def to_dict(self) -> dict:
"""
Expand All @@ -166,6 +214,9 @@ def to_dict(self) -> dict:
items = self.value.value_type
if isinstance(items, pa.DataType):
return {"type": "array", "items": Type(items).to_dict()}
elif isinstance(self.value, pa.StructType):
fields = [(field.name, Type(field.type).to_dict()) for field in self.value]
return {"type": "object", "properties": dict(fields)}

type_ = None
for type_name, data_type in _TYPES.items():
Expand Down
15 changes: 13 additions & 2 deletions src/fondant/core/schemas/common.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
"binary",
"list",
"struct",
"array"
"array",
"object"
]
},
"field": {
Expand All @@ -37,6 +38,16 @@
"type": "string",
"$ref": "#/definitions/subset_data_type"
},
"properties": {
"type": "object",
"properties": {
"type": {
"type": "string",
"$ref": "#/definitions/subset_data_type"
}
},
"additionalProperties": true
},
"items": {
"oneOf": [
{
Expand Down Expand Up @@ -76,4 +87,4 @@
}
}
}
}
}
19 changes: 19 additions & 0 deletions tests/component/examples/component_specs/component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,25 @@ produces:
items:
type: float32

element:
type: object
properties:
id:
type: string
number:
type: int32

elements:
type: array
items:
type: object
properties:
id:
type: string
number:
type: int32


args:
flag:
description: user argument
Expand Down
17 changes: 16 additions & 1 deletion tests/component/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,22 @@ def _patched_data_loading(monkeypatch):
"""Mock data loading so no actual data is loaded."""

def mocked_load_dataframe(self):
return dd.from_dict({"images_data": [1, 2, 3]}, npartitions=N_PARTITIONS)
return dd.from_dict(
{
"images_data": [1, 2, 3],
"element": [
("1", 1),
("2", 2),
("3", 3),
],
"elements": [
[("1", 1), ("2", 2), ("3", 3)],
[("4", 4), ("5", 5), ("6", 6)],
[("7", 7), ("8", 8), ("9", 9)],
],
},
npartitions=N_PARTITIONS,
)

monkeypatch.setattr(DaskDataLoader, "load_dataframe", mocked_load_dataframe)

Expand Down
34 changes: 34 additions & 0 deletions tests/core/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@ def test_valid_type():
"type": "array",
"items": {"type": "float32"},
}
assert Type.struct(
[
("f1", "int32"),
("f2", "string"),
("f3", Type.list("int8")),
("f4", Type.struct([("f5", "int32")])),
],
).to_dict() == {
"type": "object",
"properties": {
"f1": {"type": "int32"},
"f2": {"type": "string"},
"f3": {"type": "array", "items": {"type": "int8"}},
"f4": {"type": "object", "properties": {"f5": {"type": "int32"}}},
},
}


def test_valid_json_schema():
Expand All @@ -26,6 +42,24 @@ def test_valid_json_schema():
assert Type.from_dict(
{"type": "array", "items": {"type": "array", "items": {"type": "int8"}}},
).value == pa.list_(pa.list_(pa.int8()))
assert Type.from_dict(
{
"type": "object",
"properties": {
"f1": {"type": "int32"},
"f2": {"type": "string"},
"f3": {"type": "array", "items": {"type": "int8"}},
"f4": {"type": "object", "properties": {"f5": {"type": "int32"}}},
},
},
) == Type.struct(
[
("f1", "int32"),
("f2", "string"),
("f3", Type.list("int8")),
("f4", Type.struct([("f5", "int32")])),
],
)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 6640d88

Please sign in to comment.