Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add struct to types #879

Merged
merged 4 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 55 additions & 4 deletions src/fondant/core/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def get_path(self):
"date64": pa.date64(),
"duration": pa.duration("us"),
"string": pa.string(),
"struct": pa.struct([]),
"utf8": pa.utf8(),
"binary": pa.binary(),
"large_binary": pa.large_binary(),
Expand Down Expand Up @@ -135,6 +136,31 @@ def list(cls, data_type: t.Union[str, pa.DataType, "Type"]) -> "Type":
pa.list_(data_type.value if isinstance(data_type, Type) else data_type),
)

@classmethod
def struct(
    cls,
    fields: t.List[t.Tuple[str, t.Union[str, pa.DataType, "Type"]]],
) -> "Type":
    """
    Build a `Type` from a pyarrow struct assembled out of the given fields.

    Args:
        fields: A list of `(name, type)` tuples. Each type may be a `Type`
            instance, a pyarrow data type, or any other value, which is
            passed through `_validate_data_type`.

    Returns:
        A new `Type` instance constructed from the resulting pyarrow struct.
    """

    def _resolve(data_type):
        # `Type` wrappers are unwrapped to their `.value`, pyarrow types are
        # used as-is, and anything else is handed to the class validator.
        if isinstance(data_type, Type):
            return data_type.value
        if isinstance(data_type, pa.DataType):
            return data_type
        return cls._validate_data_type(data_type)

    arrow_fields = [
        pa.field(name, _resolve(data_type)) for name, data_type in fields
    ]
    return cls(pa.struct(arrow_fields))

@classmethod
def from_dict(cls, json_schema: dict):
"""
Expand All @@ -147,13 +173,35 @@ def from_dict(cls, json_schema: dict):
Returns:
A new `Type` instance representing the specified data type.
"""
if json_schema["type"] == "array":
items = json_schema["items"]
type_name = json_schema.get("type")

if type_name is None:
msg = "Invalid or missing 'type' key in the schema."
raise InvalidTypeSchema(msg)

if type_name == "array":
items = json_schema.get("items")
if isinstance(items, dict):
return cls.list(cls.from_dict(items))
return None
if isinstance(items, str):
return cls.list(items)

msg = "Invalid 'items' type in array schema."
raise InvalidTypeSchema(msg)

return cls(json_schema["type"])
if type_name == "object":
properties = json_schema.get("properties")
if not isinstance(properties, dict):
msg = "Invalid 'properties' type in object schema."
raise InvalidTypeSchema(msg)
fields = [(name, cls.from_dict(prop)) for name, prop in properties.items()]
return cls.struct(fields)

if isinstance(type_name, str):
return cls(type_name)

msg = f"Invalid 'type' value: {type_name}"
raise InvalidTypeSchema(msg)

def to_dict(self) -> dict:
"""
Expand All @@ -166,6 +214,9 @@ def to_dict(self) -> dict:
items = self.value.value_type
if isinstance(items, pa.DataType):
return {"type": "array", "items": Type(items).to_dict()}
elif isinstance(self.value, pa.StructType):
fields = [(field.name, Type(field.type).to_dict()) for field in self.value]
return {"type": "object", "properties": dict(fields)}

type_ = None
for type_name, data_type in _TYPES.items():
Expand Down
15 changes: 13 additions & 2 deletions src/fondant/core/schemas/common.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
"binary",
"list",
"struct",
"array"
"array",
"object"
]
},
"field": {
Expand All @@ -37,6 +38,16 @@
"type": "string",
"$ref": "#/definitions/subset_data_type"
},
"properties": {
"type": "object",
"properties": {
"type": {
"type": "string",
"$ref": "#/definitions/subset_data_type"
}
},
"additionalProperties": true
},
"items": {
"oneOf": [
{
Expand Down Expand Up @@ -76,4 +87,4 @@
}
}
}
}
}
19 changes: 19 additions & 0 deletions tests/component/examples/component_specs/component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,25 @@ produces:
items:
type: float32

element:
type: object
properties:
id:
type: string
number:
type: int32

elements:
type: array
items:
type: object
properties:
id:
type: string
number:
type: int32


args:
flag:
description: user argument
Expand Down
17 changes: 16 additions & 1 deletion tests/component/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,22 @@ def _patched_data_loading(monkeypatch):
"""Mock data loading so no actual data is loaded."""

def mocked_load_dataframe(self):
return dd.from_dict({"images_data": [1, 2, 3]}, npartitions=N_PARTITIONS)
return dd.from_dict(
{
"images_data": [1, 2, 3],
"element": [
("1", 1),
("2", 2),
("3", 3),
],
"elements": [
[("1", 1), ("2", 2), ("3", 3)],
[("4", 4), ("5", 5), ("6", 6)],
[("7", 7), ("8", 8), ("9", 9)],
],
},
npartitions=N_PARTITIONS,
)

monkeypatch.setattr(DaskDataLoader, "load_dataframe", mocked_load_dataframe)

Expand Down
34 changes: 34 additions & 0 deletions tests/core/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@ def test_valid_type():
"type": "array",
"items": {"type": "float32"},
}
assert Type.struct(
[
("f1", "int32"),
("f2", "string"),
("f3", Type.list("int8")),
("f4", Type.struct([("f5", "int32")])),
],
).to_dict() == {
"type": "object",
"properties": {
"f1": {"type": "int32"},
"f2": {"type": "string"},
"f3": {"type": "array", "items": {"type": "int8"}},
"f4": {"type": "object", "properties": {"f5": {"type": "int32"}}},
},
}


def test_valid_json_schema():
Expand All @@ -26,6 +42,24 @@ def test_valid_json_schema():
assert Type.from_dict(
{"type": "array", "items": {"type": "array", "items": {"type": "int8"}}},
).value == pa.list_(pa.list_(pa.int8()))
assert Type.from_dict(
{
"type": "object",
"properties": {
"f1": {"type": "int32"},
"f2": {"type": "string"},
"f3": {"type": "array", "items": {"type": "int8"}},
"f4": {"type": "object", "properties": {"f5": {"type": "int32"}}},
},
},
) == Type.struct(
[
("f1", "int32"),
("f2", "string"),
("f3", Type.list("int8")),
("f4", Type.struct([("f5", "int32")])),
],
)


@pytest.mark.parametrize(
Expand Down
Loading