diff --git a/jupytext/metadata_filter.py b/jupytext/metadata_filter.py index 996ef53ca..c3f28af0f 100644 --- a/jupytext/metadata_filter.py +++ b/jupytext/metadata_filter.py @@ -1,45 +1,45 @@ """Notebook and cell metadata filtering""" -def parse_metadata_config(metadata_config): +def parse_metadata_config(metadata_config, actual_keys, filtered_keys=None): """Return additional and excluded sets that correspond to a config of the form 'entry_one,entry_two-negative_entry_one,negative_entry_two""" - if not metadata_config: - return set(), set() - - if isinstance(metadata_config, list): - return set(metadata_config), set() - - if isinstance(metadata_config, tuple): - additional, excluded = metadata_config - return set(additional), set(excluded) + if metadata_config is True: + metadata_config = 'all' + elif metadata_config is False: + metadata_config = '-all' + elif metadata_config is None: + metadata_config = '' if '-' in metadata_config: additional, excluded = metadata_config.split('-', 1) - additional = additional.split(',') - excluded = excluded.split(',') - return set(additional), set(excluded) + excluded = set(excluded.split(',')).difference({''}) + if not additional and 'all' not in excluded: + additional = set(filtered_keys or actual_keys) + else: + additional = set(additional.split(',')).difference({''}) + else: + additional = set(metadata_config.split(',')).difference({''}) + excluded = set() + + if 'all' in additional: + additional = actual_keys + if 'all' in excluded: + excluded = actual_keys.difference(additional) - additional = metadata_config.split(',') - return set(additional), set() + return additional, excluded def filter_metadata(metadata, user_metadata_config, default_metadata_config): """Filter the cell or notebook metadata, according to the user preference""" - if user_metadata_config is True: - return metadata - - if user_metadata_config is False: - return {} - actual_keys = set(metadata.keys()) - default_positive, default_negative = parse_metadata_config(default_metadata_config) - user_positive, user_negative = parse_metadata_config(user_metadata_config) + default_positive, default_negative = parse_metadata_config(default_metadata_config, actual_keys) + user_positive, user_negative = parse_metadata_config( + user_metadata_config, actual_keys, + actual_keys.intersection(default_positive).difference(default_negative)) - if not default_positive: - keep_keys = actual_keys.difference(default_negative.difference(user_positive)).difference(user_negative) - else: - keep_keys = actual_keys.intersection(default_positive.union(user_positive)).difference(user_negative) + keep_keys = actual_keys.intersection(default_positive.difference(user_negative).union(user_positive)) \ + .difference(default_negative.difference(user_positive).union(user_negative)) for key in actual_keys: if key not in keep_keys: diff --git a/tests/test_metadata_filter.py b/tests/test_metadata_filter.py new file mode 100644 index 000000000..a91b2742e --- /dev/null +++ b/tests/test_metadata_filter.py @@ -0,0 +1,28 @@ +from jupytext.metadata_filter import filter_metadata + + +def to_dict(keys): + return {key: None for key in keys} + + +def test_metadata_filter_default(): + assert filter_metadata(to_dict(['technical', 'user', 'preserve']), None, '-technical' + ) == to_dict(['user', 'preserve']) + assert filter_metadata(to_dict(['technical', 'user', 'preserve']), None, 'preserve-all' + ) == to_dict(['preserve']) + + +def test_metadata_filter_user_plus_default(): + assert filter_metadata(to_dict(['technical', 'user', 'preserve']), '-user', '-technical' + ) == to_dict(['preserve']) + assert filter_metadata(to_dict(['technical', 'user', 'preserve']), 'all-user', '-technical' + ) == to_dict(['preserve', 'technical']) + assert filter_metadata(to_dict(['technical', 'user', 'preserve']), 'user', 'preserve-all' + ) == to_dict(['user', 'preserve']) + + +def test_metadata_filter_user_overrides_default(): + assert filter_metadata(to_dict(['technical', 'user', 'preserve']), 'all-user', '-technical' + ) == to_dict(['technical', 'preserve']) + assert filter_metadata(to_dict(['technical', 'user', 'preserve']), 'user-all', 'preserve' + ) == to_dict(['user'])