[PROTOTYPE NOT TO BE MERGED] Dataset factories prototype #2560

Closed
wants to merge 14 commits
1 change: 1 addition & 0 deletions dependency/requirements.txt
@@ -12,6 +12,7 @@ importlib_resources>=1.3 # The `files()` API was introduced in `importlib_resources` 1.3
jmespath>=0.9.5, <1.0
more_itertools~=9.0
omegaconf~=2.3
parse
pip-tools~=6.5
pluggy~=1.0.0
PyYAML>=4.2, <7.0
61 changes: 61 additions & 0 deletions kedro/framework/cli/catalog.py
@@ -1,9 +1,12 @@
"""A collection of CLI commands for working with Kedro catalog."""
import copy
from collections import defaultdict
from typing import Iterable

import click
import yaml
from click import secho
from parse import parse

from kedro.framework.cli.utils import KedroCliError, env_option, split_string
from kedro.framework.project import pipelines, settings
@@ -174,3 +177,61 @@ def _add_missing_datasets_to_catalog(missing_ds, catalog_path):
catalog_path.parent.mkdir(exist_ok=True)
with catalog_path.open(mode="w") as catalog_file:
yaml.safe_dump(catalog_config, catalog_file, default_flow_style=False)


@catalog.command("show")
@env_option
@click.pass_obj
def show_catalog_datasets(metadata: ProjectMetadata, env):
session = _create_session(metadata.package_name, env=env)
context = session.load_context()
catalog_conf = context.config_loader["catalog"]
secho(yaml.dump(catalog_conf))


@catalog.command("resolve")
@env_option
@click.pass_obj
def resolve_catalog_datasets(metadata: ProjectMetadata, env):
session = _create_session(metadata.package_name, env=env)
context = session.load_context()
catalog_conf = context.config_loader["catalog"]

# Create a list of all datasets used in the project pipelines.
pipeline_datasets = []
for pipe in pipelines.keys():
pl_obj = pipelines.get(pipe)
if pl_obj:
pipeline_ds = pl_obj.data_sets()
for ds in pipeline_ds:
pipeline_datasets.append(ds)
else:
existing_pls = ", ".join(sorted(pipelines.keys()))
raise KedroCliError(
f"'{pipe}' pipeline not found! Existing pipelines: {existing_pls}"
)
Comment on lines +201 to +212
@idanov (Member), May 10, 2023:

Arguably shorter and more readable than the original snippet:

Suggested change
pipeline_datasets = []
for pipe in pipelines.keys():
pl_obj = pipelines.get(pipe)
if pl_obj:
pipeline_ds = pl_obj.data_sets()
for ds in pipeline_ds:
pipeline_datasets.append(ds)
else:
existing_pls = ", ".join(sorted(pipelines.keys()))
raise KedroCliError(
f"'{pipe}' pipeline not found! Existing pipelines: {existing_pls}"
)
ds_groups = [pipe.data_sets() for pipe in pipelines.values() if pipe]
pipeline_datasets = [ds for group in ds_groups for ds in group]

Not sure why we need the error?


# Create a copy of the catalog config to not modify the original.
catalog_copy = copy.deepcopy(catalog_conf)
# Loop over all entries in the catalog, find the ones that contain a pattern to be matched,
# loop over all datasets in the pipeline and match these against the patterns.
# Then expand the matches and add them to the catalog copy to display on the CLI.
for ds_name, ds_config in catalog_conf.items():
if "}" in ds_name:
for pipeline_dataset in set(pipeline_datasets):
result = parse(ds_name, pipeline_dataset)
if result:
config_copy = copy.deepcopy(ds_config)
# Match results to patterns in catalog entry
for key, value in config_copy.items():
if isinstance(value, Iterable) and "}" in value:
string_value = str(value)
config_copy[key] = string_value.format_map(result.named)
catalog_copy[pipeline_dataset] = config_copy

Comment on lines +219 to +231
Member:

I find this extremely hard to read if not for the comments. I think it would be much better if the code actually does what the comment says, step by step, rather than nesting levels.

Moreover, this complexity seems to have hidden a logical incorrectness in the code: all datasets that appear exactly as they are in the catalog shouldn't be matched against the patterns (which the original code doesn't do).

Suggested change
for ds_name, ds_config in catalog_conf.items():
if "}" in ds_name:
for pipeline_dataset in set(pipeline_datasets):
result = parse(ds_name, pipeline_dataset)
if result:
config_copy = copy.deepcopy(ds_config)
# Match results to patterns in catalog entry
for key, value in config_copy.items():
if isinstance(value, Iterable) and "}" in value:
string_value = str(value)
config_copy[key] = string_value.format_map(result.named)
catalog_copy[pipeline_dataset] = config_copy
skip = [ds_name for ds_name in catalog_conf.keys() if "}" not in ds_name]
datasets = set(pipeline_datasets) - set(skip)
patterns = [ds_name for ds_name in catalog_conf.keys() if "}" in ds_name]
matches = [(ds, pattern, parse(pattern, ds)) for ds in datasets for pattern in patterns]
matches = [(ds, pattern, result) for ds, pattern, result in matches if result]
for ds, pattern, result in matches:
cfg = copy.deepcopy(catalog_conf[pattern])
# Match results to patterns in catalog entry
for key, value in cfg.items():
if isinstance(value, Iterable) and "}" in value:
string_value = str(value)
cfg[key] = string_value.format_map(result.named)
catalog_copy[ds] = cfg

The code for populating the template should probably be a separate function for readability and also to ensure that we can call it recursively (as each value can be a dictionary with other values, etc).

The original code does all the matching and doesn't stop at the first match; maybe that's intended so we can see all possible matches?

@merelcht (Member, Author), May 11, 2023:

Oh yeah, you can completely ignore this code. As I wrote in the description, this is just a bonus for the prototype to show what's possible, but in no way meant to be merged. I just went for a quick implementation, because this needs proper parsing logic and will change anyway.

# Remove all patterns from the resolved catalog
for ds_name, ds_config in catalog_conf.items():
if "}" in ds_name:
del catalog_copy[ds_name]

secho(yaml.dump(catalog_copy))
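
For context on the matching mechanism used by `catalog resolve` above, here is a small standalone sketch (illustrative names, not part of this PR) of how the `parse` library turns a catalog pattern into named values and how `format_map` fills them back into a config template:

from parse import parse

pattern = "{root_namespace}.{dataset_name}"
result = parse(pattern, "germany.companies")
# result.named == {'root_namespace': 'germany', 'dataset_name': 'companies'}

filepath_template = "data/01_raw/{root_namespace}_{dataset_name}.csv"
# format_map substitutes the matched values into the {...} placeholders
print(filepath_template.format_map(result.named))  # data/01_raw/germany_companies.csv
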
2 changes: 1 addition & 1 deletion kedro/framework/project/__init__.py
@@ -112,7 +112,6 @@ class _ProjectSettings(LazySettings):
)

def __init__(self, *args, **kwargs):

kwargs.update(
validators=[
self._CONF_SOURCE,
@@ -133,6 +132,7 @@ def _load_data_wrapper(func):
"""Wrap a method in _ProjectPipelines so that data is loaded on first access.
Taking inspiration from dynaconf.utils.functional.new_method_proxy
"""

# pylint: disable=protected-access
def inner(self, *args, **kwargs):
self._load_data()
128 changes: 106 additions & 22 deletions kedro/io/data_catalog.py
@@ -9,7 +9,9 @@
import logging
import re
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Type, Union
from typing import Any, Dict, List, Optional, Set, Type, Union, Iterable

from parse import parse

from kedro.io.core import (
AbstractDataSet,
@@ -139,6 +141,7 @@ def __init__(
data_sets: Dict[str, AbstractDataSet] = None,
feed_dict: Dict[str, Any] = None,
layers: Dict[str, Set[str]] = None,
dataset_patterns: Dict[str, Any] = None,
) -> None:
"""``DataCatalog`` stores instances of ``AbstractDataSet``
implementations to provide ``load`` and ``save`` capabilities from
@@ -168,6 +171,9 @@ def __init__(
self._data_sets = dict(data_sets or {})
self.datasets = _FrozenDatasets(self._data_sets)
self.layers = layers
# Keep a record of all patterns in the catalog.
# {dataset pattern name : dataset pattern body}
self.dataset_patterns = dict(dataset_patterns or {})

# import the feed dict
if feed_dict:
@@ -255,6 +261,7 @@ class to be loaded is specified with the key ``type`` and their
>>> catalog.save("boats", df)
"""
data_sets = {}
dataset_patterns = {}
Contributor:

What is the difference between these two patterns?

catalog = copy.deepcopy(catalog) or {}
credentials = copy.deepcopy(credentials) or {}
save_version = save_version or generate_timestamp()
@@ -269,35 +276,52 @@ class to be loaded is specified with the key ``type`` and their

layers: Dict[str, Set[str]] = defaultdict(set)
for ds_name, ds_config in catalog.items():
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)

ds_config = _resolve_credentials(ds_config, credentials)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)
# Let's assume that any name with } in it is a dataset pattern to be matched.
if "}" in ds_name:
# Add each pattern to the dataset_patterns dict.
dataset_patterns[ds_name] = ds_config
else:
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)

ds_config = _resolve_credentials(ds_config, credentials)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)
dataset_layers = layers or None
return cls(data_sets=data_sets, layers=dataset_layers)
return cls(
data_sets=data_sets,
layers=dataset_layers,
dataset_patterns=dataset_patterns,
)

def _get_dataset(
self, data_set_name: str, version: Version = None, suggest: bool = True
) -> AbstractDataSet:
logging.warning(f"Getting data for {data_set_name}")
if data_set_name not in self._data_sets:
error_msg = f"DataSet '{data_set_name}' not found in the catalog"
# When a dataset is "used" in the pipeline that's not in the recorded catalog datasets,
# try to match it against the patterns in the catalog. If it's a match, resolve it to
# a dataset instance and add it to the catalog, so it only needs to be matched once
# and not every time the dataset is used in the pipeline.
matched_dataset = self.match_name_against_dataset_factories(data_set_name)
if matched_dataset:
self.add(data_set_name, matched_dataset)
else:
error_msg = f"DataSet '{data_set_name}' not found in the catalog"

# Flag to turn on/off fuzzy-matching which can be time consuming and
# slow down plugins like `kedro-viz`
if suggest:
matches = difflib.get_close_matches(
data_set_name, self._data_sets.keys()
)
if matches:
suggestions = ", ".join(matches)
error_msg += f" - did you mean one of these instead: {suggestions}"
# Flag to turn on/off fuzzy-matching which can be time consuming and
# slow down plugins like `kedro-viz`
if suggest:
matches = difflib.get_close_matches(
data_set_name, self._data_sets.keys()
)
if matches:
suggestions = ", ".join(matches)
error_msg += f" - did you mean one of these instead: {suggestions}"

raise DataSetNotFoundError(error_msg)
raise DataSetNotFoundError(error_msg)

data_set = self._data_sets[data_set_name]
if version and isinstance(data_set, AbstractVersionedDataSet):
@@ -520,6 +544,35 @@ def add_feed_dict(self, feed_dict: Dict[str, Any], replace: bool = False) -> None:

self.add(data_set_name, data_set, replace)

def match_name_against_dataset_factories(self, dataset_input_name: str) -> Optional[AbstractDataSet]:
Member:

Crazily long name 😅

"""
For a given dataset name, try to match it against the dataset patterns in the catalog.
If it's a match, return the dataset instance.
"""
logging.warning(f"Matching dataset {dataset_input_name}")
dataset = None
# Loop through all dataset patterns and check if the given dataset name has a match.
for dataset_name, dataset_config in self.dataset_patterns.items():
Member:

The dataset_name here is a pattern and the dataset_config is a template/factory, right? We could use that terminology for ease and readability.

result = parse(dataset_name, dataset_input_name)
# If there's a match resolve the rest of the pattern to create a dataset instance.
# A result can be None or something like:
# <Result () {'root_namespace': 'germany', 'dataset_name': 'companies'}>
if result:
config_copy = copy.deepcopy(dataset_config)
# Match results to patterns in catalog entry
for key, value in config_copy.items():
# Find all dataset fields that need to be resolved with
# the values that were matched.
if isinstance(value, Iterable) and "}" in value:
string_value = str(value)
# result.named: {'root_namespace': 'germany', 'dataset_name': 'companies'}
# format_map fills in dict values into a string with {...} placeholders
# of the same key name.
config_copy[key] = string_value.format_map(result.named)
Comment on lines +561 to +571
Member:

As above, I would make this a function for readability and call it recursively down the line, returning the materialised config dictionary.
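
A minimal sketch of what such a recursive helper could look like (illustrative only, not part of this PR; `resolved_values` would be the `result.named` dict produced by `parse`):

def _resolve_config(config, resolved_values):
    # Sketch: fill {placeholders} in every string value, descending into
    # nested dictionaries and lists so deeply nested templates are resolved too.
    if isinstance(config, dict):
        return {key: _resolve_config(value, resolved_values) for key, value in config.items()}
    if isinstance(config, (list, tuple)):
        return [_resolve_config(value, resolved_values) for value in config]
    if isinstance(config, str) and "}" in config:
        return config.format_map(resolved_values)
    return config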

# Create dataset from catalog config.
dataset = AbstractDataSet.from_config(dataset_name, config_copy)
return dataset
Member:

Probably worth caching the match here before returning and adding it to self._data_sets.


def list(self, regex_search: Optional[str] = None) -> List[str]:
"""
List of all ``DataSet`` names registered in the catalog.
@@ -565,13 +618,44 @@ def list(self, regex_search: Optional[str] = None) -> List[str]:
) from exc
return [dset_name for dset_name in self._data_sets if pattern.search(dset_name)]

def exists_in_catalog(self, dataset_name: str) -> bool:
"""Check if a dataset exists in the catalog as an exact match or if it matches a pattern."""
if dataset_name in self._data_sets:
return True

if self.dataset_patterns and any(
parse(pattern, dataset_name) for pattern in self.dataset_patterns
):
return True
return False

def remove_pattern_matches(self, dataset_list: Set[str]):
"""Helper method that checks which dataset names match a pattern in the catalog.
It returns a copy of the original list minus all those matched dataset names."""
if self.dataset_patterns:
dataset_list_minus_matched = []
for dataset in dataset_list:
# If dataset matches a pattern, remove it from the list.
for dataset_name in self.dataset_patterns.keys():
result = parse(dataset_name, dataset)
if result:
break
else:
dataset_list_minus_matched.append(dataset)
return set(dataset_list_minus_matched)
return dataset_list
Comment on lines +635 to +646
Member:

This is much nicer:

Suggested change
if self.dataset_patterns:
dataset_list_minus_matched = []
for dataset in dataset_list:
# If dataset matches a pattern, remove it from the list.
for dataset_name in self.dataset_patterns.keys():
result = parse(dataset_name, dataset)
if result:
break
else:
dataset_list_minus_matched.append(dataset)
return set(dataset_list_minus_matched)
return dataset_list
if not self.dataset_patterns:
return dataset_list
dataset_list_minus_matched = dataset_list.copy()
for dataset in dataset_list:
matches = (parse(pattern, dataset) for pattern in self.dataset_patterns.keys())
# If dataset matches any pattern, remove it from the list.
if any(matches):
dataset_list_minus_matched -= {dataset}
return set(dataset_list_minus_matched)

As a general rule:

  1. The shorter case in an if statement should always go first
  2. If there's return, there's no need of else (like in the original code)
  3. There's usually a nicer solution than using a break in a loop (not always though)
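
Along those lines, a compact alternative sketch (not from the PR; it assumes `parse` is imported and `self.dataset_patterns` is the dict added in this prototype) that avoids both the break and the loop-mutation:

def remove_pattern_matches(self, dataset_list):
    # Sketch: keep only the dataset names that do not match any catalog pattern.
    return {
        ds
        for ds in dataset_list
        if not any(parse(pattern, ds) for pattern in self.dataset_patterns)
    }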


def shallow_copy(self) -> "DataCatalog":
"""Returns a shallow copy of the current object.

Returns:
Copy of the current object.
"""
return DataCatalog(data_sets=self._data_sets, layers=self.layers)
return DataCatalog(
data_sets=self._data_sets,
layers=self.layers,
dataset_patterns=self.dataset_patterns,
)

def __eq__(self, other):
return (self._data_sets, self.layers) == (other._data_sets, other.layers)
25 changes: 21 additions & 4 deletions kedro/runner/runner.py
@@ -73,14 +73,31 @@ def run(
hook_manager = hook_manager or _NullPluginManager()
catalog = catalog.shallow_copy()

unsatisfied = pipeline.inputs() - set(catalog.list())
# Check if there are any input datasets that aren't in the catalog and
# don't match a pattern in the catalog.
unsatisfied = [
input_name
for input_name in pipeline.inputs()
if not catalog.exists_in_catalog(input_name)
]
if unsatisfied:
raise ValueError(
f"Pipeline input(s) {unsatisfied} not found in the DataCatalog"
)

free_outputs = pipeline.outputs() - set(catalog.list())
unregistered_ds = pipeline.data_sets() - set(catalog.list())
# Check if there's any output datasets that aren't in the catalog and don't match a pattern
# in the catalog.
free_outputs = [
output_name
for output_name in pipeline.outputs()
if not catalog.exists_in_catalog(output_name)
]

# Check which datasets used in the pipeline aren't in the catalog and don't match
# a pattern in the catalog and create a default dataset for those datasets.
unregistered_ds = [
ds for ds in pipeline.data_sets() if not catalog.exists_in_catalog(ds)
]
logging.warning(f"UNREGISTERED DS: {unregistered_ds}")
for ds_name in unregistered_ds:
catalog.add(ds_name, self.create_default_data_set(ds_name))
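
To make the intent of these runner changes concrete, a hedged end-to-end sketch of the behaviour the prototype is aiming for (illustrative only, not part of the PR; assumes `pandas.CSVDataSet` is importable, e.g. via kedro-datasets):

from kedro.io import DataCatalog

conf = {
    "{root_namespace}.{dataset_name}": {
        "type": "pandas.CSVDataSet",
        "filepath": "data/01_raw/{root_namespace}_{dataset_name}.csv",
    },
}
catalog = DataCatalog.from_config(conf)

# "germany.companies" has no explicit catalog entry, but it matches the pattern,
# so the runner's exists_in_catalog() check passes and _get_dataset() materialises
# a CSVDataSet with the resolved filepath, registering it for subsequent use.
assert catalog.exists_in_catalog("germany.companies")
dataset = catalog._get_dataset("germany.companies")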
