Skip to content

Commit

Permalink
Enable defining nested data types (#193)
Browse files Browse the repository at this point in the history
PR that addresses #178 
Inspiration:
https://swagger.io/docs/specification/data-models/data-types/#:~:text=the%20null%20value.-,Arrays,-Arrays%20are%20defined
(credit to @GeorgesLorre)

It includes: 
* Changing the common json schema to enable defining array types 
* Changing the `Type` class to be dynamic rather than typed to enable
defining nested structures without explicitly defining them
* Adding an addition test file for `schema.py`
* Updating the current existing examples (embedding + segmentation)
  • Loading branch information
PhilippeMoussalli authored Jun 9, 2023
1 parent ee5bd30 commit 2ce1b0e
Show file tree
Hide file tree
Showing 24 changed files with 290 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ consumes:
embeddings:
fields:
data:
type: float32_list
type: array
items:
type: float32

produces:
images:
Expand Down
4 changes: 3 additions & 1 deletion components/image_embedding/fondant_component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ produces:
embeddings:
fields:
data:
type: float32_list
type: array
items:
type: float32

args:
model_id:
Expand Down
4 changes: 3 additions & 1 deletion components/segment_images/fondant_component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ produces:
segmentations:
fields:
data:
type: binary
type: array
items:
type: binary

args:
model_id:
Expand Down
20 changes: 15 additions & 5 deletions docs/component_spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ produces:
embeddings:
fields:
data:
type: int8_list
type: array
items:
type: float32
...
```

Expand Down Expand Up @@ -220,7 +222,9 @@ produces:
embeddings:
fields:
data:
type: binary
type: array
items:
type: float32
```

</td>
Expand Down Expand Up @@ -326,7 +330,9 @@ produces:
embeddings:
fields:
data:
type: binary
type: array
items:
type: float32
```

</td>
Expand Down Expand Up @@ -425,7 +431,9 @@ produces:
embeddings:
fields:
data:
type: binary
type: array
items:
type: float32
```

</td>
Expand Down Expand Up @@ -516,7 +524,9 @@ produces:
embeddings:
fields:
data:
type: binary
type: array
items:
type: float32
additionalSubsets: false
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ consumes:
segmentations:
fields:
data:
type: binary
type: array
items:
type: binary

args:
hf_token:
Expand Down
2 changes: 1 addition & 1 deletion fondant/component_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __repr__(self) -> str:
def fields(self) -> t.Mapping[str, Field]:
return types.MappingProxyType(
{
name: Field(name=name, type=Type[field["type"]])
name: Field(name=name, type=Type.from_json(field))
for name, field in self._specification["fields"].items()
}
)
Expand Down
4 changes: 4 additions & 0 deletions fondant/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,7 @@ class InvalidComponentSpec(ValidationError, FondantException):

class InvalidPipelineDefinition(ValidationError, FondantException):
"""Thrown when a pipeline definition is invalid."""


class InvalidTypeSchema(ValidationError, FondantException):
"""Thrown when a Type schema definition is invalid."""
10 changes: 5 additions & 5 deletions fondant/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def fields(self) -> t.Mapping[str, Field]:
"""The fields of the subset returned as an immutable mapping."""
return types.MappingProxyType(
{
name: Field(name=name, type=Type[field["type"]])
name: Field(name=name, type=Type.from_json(field))
for name, field in self._specification["fields"].items()
}
)
Expand All @@ -47,7 +47,7 @@ def add_field(self, name: str, type_: Type, *, overwrite: bool = False) -> None:
if not overwrite and name in self._specification["fields"]:
raise ValueError(f"A field with name {name} already exists")

self._specification["fields"][name] = {"type": type_.name}
self._specification["fields"][name] = type_.to_json()

def remove_field(self, name: str) -> None:
del self._specification["fields"][name]
Expand All @@ -62,8 +62,8 @@ class Index(Subset):
@property
def fields(self) -> t.Dict[str, Field]:
return {
"id": Field(name="id", type=Type.string),
"source": Field(name="source", type=Type.string),
"id": Field(name="id", type=Type("string")),
"source": Field(name="source", type=Type("string")),
}


Expand Down Expand Up @@ -178,7 +178,7 @@ def add_subset(

self._specification["subsets"][name] = {
"location": f"/{name}/{self.run_id}/{self.component_id}",
"fields": {name: {"type": type_.name} for name, type_ in fields},
"fields": {name: type_.to_json() for name, type_ in fields},
}

def remove_subset(self, name: str) -> None:
Expand Down
175 changes: 133 additions & 42 deletions fondant/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,56 +2,147 @@
and pipelines.
"""

import enum
import typing as t

import pyarrow as pa

KubeflowCommandArguments = t.List[t.Union[str, t.Dict[str, str]]]

from fondant.exceptions import InvalidTypeSchema

class Type(enum.Enum):
"""Supported types.
KubeflowCommandArguments = t.List[t.Union[str, t.Dict[str, str]]]

Based on:
- https://arrow.apache.org/docs/python/api/datatypes.html#api-types
- https://pola-rs.github.io/polars/py-polars/html/reference/datatypes.html
"""
Types based on:
- https://arrow.apache.org/docs/python/api/datatypes.html#api-types
"""
_TYPES: t.Dict[str, pa.DataType] = {
"null": pa.null(),
"bool": pa.bool_(),
"int8": pa.int8(),
"int16": pa.int16(),
"int32": pa.int32(),
"int64": pa.int64(),
"uint8": pa.uint8(),
"uint16": pa.uint16(),
"uint32": pa.uint32(),
"uint64": pa.uint64(),
"float16": pa.float16(),
"float32": pa.float32(),
"float64": pa.float64(),
"decimal128": pa.decimal128(38),
"time32": pa.time32("s"),
"time64": pa.time64("us"),
"timestamp": pa.timestamp("us"),
"date32": pa.date32(),
"date64": pa.date64(),
"duration": pa.duration("us"),
"string": pa.string(),
"utf8": pa.utf8(),
"binary": pa.binary(),
"large_binary": pa.large_binary(),
"large_utf8": pa.large_utf8(),
}


class Type:
"""
The `Type` class provides a way to define and validate data types for various purposes. It
supports different data types including primitive types and complex types like lists.
"""

bool = pa.bool_()

int8 = pa.int8()
int16 = pa.int16()
int32 = pa.int32()
int64 = pa.int64()

uint8 = pa.uint8()
uint16 = pa.uint16()
uint32 = pa.uint32()
uint64 = pa.uint64()

float16 = pa.float16()
float32 = pa.float32()
float64 = pa.float64()

decimal = pa.decimal128(38)

time32 = pa.time32("s")
time64 = pa.time64("us")
timestamp = pa.timestamp("us")

date32 = pa.date32()
date64 = pa.date64()
duration = pa.duration("us")

string = pa.string()
utf8 = pa.utf8()

binary = pa.binary()

int8_list = pa.list_(pa.int8())

float32_list = pa.list_(pa.float32())
def __init__(self, data_type: t.Union[str, pa.DataType]):
self.value = self._validate_data_type(data_type)

@staticmethod
def _validate_data_type(data_type: t.Union[str, pa.DataType]) -> pa.DataType:
"""
Validates the provided data type and returns the corresponding data type object.
Args:
data_type: The data type to validate.
Returns:
The validated `pa.DataType` object.
"""
if not isinstance(data_type, (Type, pa.DataType)):
try:
data_type = _TYPES[data_type]
except KeyError:
raise InvalidTypeSchema(
f"Invalid schema provided {data_type} with type {type(data_type)}."
f" Current available data types are: {_TYPES.keys()}"
)
return data_type

@classmethod
def list(cls, data_type: t.Union[str, pa.DataType, "Type"]) -> "Type":
"""
Creates a new `Type` instance representing a list of the specified data type.
Args:
data_type: The data type for the list elements. It can be a string representing the
data type or an existing `pa.DataType` object.
Returns:
A new `Type` instance representing a list of the specified data type.
"""
data_type = cls._validate_data_type(data_type)
return cls(
pa.list_(data_type.value if isinstance(data_type, Type) else data_type)
)

@classmethod
def from_json(cls, json_schema: dict):
"""
Creates a new `Type` instance based on a dictionary representation of the json schema
of a data type (https://swagger.io/docs/specification/data-models/data-types/).
Args:
json_schema: The dictionary representation of the data type, can represent nested values
Returns:
A new `Type` instance representing the specified data type.
"""
if json_schema["type"] == "array":
items = json_schema["items"]
if isinstance(items, dict):
return cls.list(cls.from_json(items))
else:
return cls(json_schema["type"])

def to_json(self) -> dict:
"""
Converts the `Type` instance to its JSON representation.
Returns:
A dictionary representing the JSON schema of the data type.
"""
if isinstance(self.value, pa.ListType):
items = self.value.value_type
if isinstance(items, pa.DataType):
return {"type": "array", "items": Type(items).to_json()}

type_ = None
for type_name, data_type in _TYPES.items():
if self.value.equals(data_type):
type_ = type_name
break

return {"type": type_}

@property
def name(self):
"""Name of the data type."""
return str(self.value)

def __repr__(self):
"""Returns a string representation of the `Type` instance."""
return f"Type({repr(self.value)})"

def __eq__(self, other):
if isinstance(other, Type):
return self.value == other.value

return False


class Field(t.NamedTuple):
Expand Down
21 changes: 16 additions & 5 deletions fondant/schemas/common.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
"binary",
"list",
"struct",
"int8_list",
"float32_list"
"array"
]
},
"field": {
Expand All @@ -36,11 +35,23 @@
"type": {
"type": "string",
"$ref": "#/definitions/subset_data_type"
},
"items": {
"oneOf": [
{
"$ref": "#/definitions/field"
},
{
"type": "array",
"items": {
"$ref": "#/definitions/field"
}
}
]
}
},
"required": [
"type"
]
"required": ["type"],
"additionalProperties": false
},
"fields": {
"type": "object",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ produces:
embeddings:
fields:
data:
type: binary
type: array
items:
type: float32

args:
storage_args:
Expand Down
Loading

0 comments on commit 2ce1b0e

Please sign in to comment.