Skip to content

Commit

Permalink
Merge feature/implement-new-dataset-format into feature/refactore-com…
Browse files Browse the repository at this point in the history
…ponent-package
  • Loading branch information
mrchtr committed Nov 22, 2023
2 parents 365ca6d + a60ca3e commit d2182a0
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 18 deletions.
6 changes: 1 addition & 5 deletions src/fondant/core/component_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,13 @@ def image(self, value: str) -> None:
def tags(self) -> t.List[str]:
return self._specification.get("tags", None)

@property
def index(self):
return Field(name="index", location=self._specification["index"].location)

@property
def consumes(self) -> t.Mapping[str, Field]:
"""The fields consumed by the component as an immutable mapping."""
return types.MappingProxyType(
{
name: Field(name=name, type=Type.from_json(field))
for name, field in self._specification.get("produces", {}).items()
for name, field in self._specification.get("consumes", {}).items()
},
)

Expand Down
17 changes: 6 additions & 11 deletions src/fondant/core/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from collections import OrderedDict
from dataclasses import asdict, dataclass
from pathlib import Path
from types import MappingProxyType

import jsonschema.exceptions
from fsspec import open as fs_open
Expand Down Expand Up @@ -173,7 +172,7 @@ def field_mapping(self) -> t.Mapping[str, t.List[str]]:
}
"""
field_mapping = {}
for field_name, field in self.fields.items():
for field_name, field in {"Index": self.index, **self.fields}.items():
location = (
f"{self.base_path}/{self.pipeline_name}/{self.run_id}{field.location}"
)
Expand All @@ -182,18 +181,14 @@ def field_mapping(self) -> t.Mapping[str, t.List[str]]:
else:
field_mapping[location] = [field_name]


# Sort field mapping that the first dataset contains the index
index_location = (
f"{self.base_path}/{self.pipeline_name}/{self.run_id}{self.index.location}"
)
sorted_keys = sorted(
field_mapping.keys(), key=lambda key: index_location == key, reverse=True
)
sorted_keys = sorted(field_mapping.keys(), key=lambda key: "Index" in field_mapping[key], reverse=True)
sorted_field_mapping = OrderedDict(
(key, field_mapping[key]) for key in sorted_keys
)

return MappingProxyType(sorted_field_mapping)
return types.MappingProxyType(sorted_field_mapping)

@property
def run_id(self) -> str:
Expand Down Expand Up @@ -238,7 +233,7 @@ def add_or_update_field(self, field: Field, overwrite: bool = False):
else:
self._specification["fields"][field.name] = {
"location": f"/{self.component_id}",
"type": field.type.name,
"type": field.type.to_json(),
}

def _add_or_update_index(self, field: Field, overwrite: bool = True):
Expand Down Expand Up @@ -301,7 +296,7 @@ def evolve( # noqa : PLR0912 (too many branches)

# Add or update all produced fields defined in the component spec
for name, field in component_spec.produces.items():
# If field was part not part of the input manifest, add field to output manifest.
# If field was not part of the input manifest, add field to output manifest.
# If field was part of the input manifest and got produced by the component, update
# the manifest field.
evolved_manifest.add_or_update_field(field, overwrite=True)
Expand Down
1 change: 0 additions & 1 deletion src/fondant/core/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ def __eq__(self, other):
return False


@dataclass
class Field:
"""Class representing a single field or column in a Fondant dataset."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
"location":"/example_component"
},
"embeddings_data": {
"type": "list<item: float>",
"type": {
"type": "array",
"items": {
"type": "float32"
}
},
"location":"/example_component"
}
}
Expand Down
1 change: 1 addition & 0 deletions tests/core/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def test_field_mapping(valid_manifest):
assert field_mapping == OrderedDict(
{
"gs://bucket/test_pipeline/test_pipeline_12345/component2": [
"Index",
"height",
"width",
],
Expand Down

0 comments on commit d2182a0

Please sign in to comment.