Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make dataset factory resolve nested dict properly #2993

Merged
merged 16 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

## Major features and improvements
## Bug fixes and other changes
* Updated dataset factories to resolve nested catalog config properly.

## Documentation changes
## Breaking changes to the API
## Upcoming deprecations for Kedro 0.19.0
Expand Down
18 changes: 16 additions & 2 deletions kedro/framework/cli/catalog.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A collection of CLI commands for working with Kedro catalog."""
import copy
from collections import defaultdict
from itertools import chain

Expand Down Expand Up @@ -84,7 +85,13 @@ def list_datasets(metadata: ProjectMetadata, pipeline, env):
data_catalog._dataset_patterns, ds_name
)
if matched_pattern:
ds_config = data_catalog._resolve_config(ds_name, matched_pattern)
ds_config_copy = copy.deepcopy(
data_catalog._dataset_patterns[matched_pattern]
)

ds_config = data_catalog._resolve_config(
ds_name, matched_pattern, ds_config_copy
)
factory_ds_by_type[ds_config["type"]].append(ds_name)

default_ds = default_ds - set(chain.from_iterable(factory_ds_by_type.values()))
Expand Down Expand Up @@ -244,7 +251,14 @@ def resolve_patterns(metadata: ProjectMetadata, env):
data_catalog._dataset_patterns, ds_name
)
if matched_pattern:
ds_config = data_catalog._resolve_config(ds_name, matched_pattern)
ds_config_copy = copy.deepcopy(
data_catalog._dataset_patterns[matched_pattern]
)

ds_config = data_catalog._resolve_config(
ds_name, matched_pattern, ds_config_copy
)

ds_config["filepath"] = _trim_filepath(
str(context.project_path) + "/", ds_config["filepath"]
)
Expand Down
41 changes: 25 additions & 16 deletions kedro/io/data_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logging
import re
from collections import defaultdict
from typing import Any, Dict, Iterable
from typing import Any, Dict

from parse import parse

Expand Down Expand Up @@ -388,7 +388,10 @@ def _get_dataset(
if data_set_name not in self._data_sets and matched_pattern:
# If the dataset is a patterned dataset, materialise it and add it to
# the catalog
data_set_config = self._resolve_config(data_set_name, matched_pattern)
config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern])
data_set_config = self._resolve_config(
data_set_name, matched_pattern, config_copy
)
ds_layer = data_set_config.pop("layer", None)
if ds_layer:
self.layers = self.layers or {}
Expand Down Expand Up @@ -436,27 +439,33 @@ def __contains__(self, data_set_name):
return True
return False

@classmethod
def _resolve_config(
self,
cls,
data_set_name: str,
matched_pattern: str,
config: dict,
) -> dict[str, Any]:
"""Get resolved AbstractDataset from a factory config"""
result = parse(matched_pattern, data_set_name)
config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern])
# Resolve the factory config for the dataset
for key, value in config_copy.items():
if isinstance(value, Iterable) and "}" in value:
# result.named: gives access to all dict items in the match result.
# format_map fills in dict values into a string with {...} placeholders
# of the same key name.
try:
config_copy[key] = str(value).format_map(result.named)
except KeyError as exc:
raise DatasetError(
f"Unable to resolve '{key}' for the pattern '{matched_pattern}'"
) from exc
return config_copy
if isinstance(config, dict):
for key, value in config.items():
config[key] = cls._resolve_config(data_set_name, matched_pattern, value)
elif isinstance(config, (list, tuple)):
config = [
cls._resolve_config(data_set_name, matched_pattern, value)
for value in config
]
elif isinstance(config, str) and "}" in config:
try:
config = str(config).format_map(result.named)
except KeyError as exc:
raise DatasetError(
f"Unable to resolve '{config}' from the pattern '{matched_pattern}'. Keys used in the config "
f"should be present in the dataset factory pattern."
ankatiyar marked this conversation as resolved.
Show resolved Hide resolved
) from exc
return config

def load(self, name: str, version: str = None) -> Any:
"""Loads a registered data set.
Expand Down
38 changes: 37 additions & 1 deletion tests/io/test_data_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,29 @@ def config_with_dataset_factories():
}


@pytest.fixture
def config_with_dataset_factories_nested():
return {
"catalog": {
"{brand}_cars": {
"type": "PartitionedDataset",
"path": "data/01_raw",
"dataset": "pandas.CSVDataSet",
"metadata": {
"my-plugin": {
"brand": "{brand}",
"list_config": [
"NA",
"{brand}",
],
"nested_list_dict": [{}, {"brand": "{brand}"}],
}
},
},
},
}


@pytest.fixture
def config_with_dataset_factories_with_default(config_with_dataset_factories):
config_with_dataset_factories["catalog"]["{default_dataset}"] = {
Expand Down Expand Up @@ -840,7 +863,10 @@ def test_unmatched_key_error_when_parsing_config(
):
"""Check error raised when key mentioned in the config is not in pattern name"""
catalog = DataCatalog.from_config(**config_with_dataset_factories_bad_pattern)
pattern = "Unable to resolve 'filepath' for the pattern '{type}@planes'"
pattern = (
"Unable to resolve 'data/01_raw/{brand}_plane.pq' from the pattern '{type}@planes'. "
"Keys used in the config should be present in the dataset factory pattern."
ankatiyar marked this conversation as resolved.
Show resolved Hide resolved
)
with pytest.raises(DatasetError, match=re.escape(pattern)):
catalog._get_dataset("jet@planes")

Expand Down Expand Up @@ -896,3 +922,13 @@ def test_factory_config_versioned(
microsecond=current_ts.microsecond // 1000 * 1000, tzinfo=None
)
assert actual_timestamp == expected_timestamp

def test_factory_nested_config(self, config_with_dataset_factories_nested):
catalog = DataCatalog.from_config(**config_with_dataset_factories_nested)
dataset = catalog._get_dataset("tesla_cars")
assert dataset.metadata["my-plugin"]["brand"] == "tesla"
assert dataset.metadata["my-plugin"]["list_config"] == ["NA", "tesla"]
assert dataset.metadata["my-plugin"]["nested_list_dict"] == [
{},
{"brand": "tesla"},
]