diff --git a/src/aiida_quantumespresso/utils/decorators.py b/src/aiida_quantumespresso/utils/decorators.py new file mode 100644 index 00000000..61a9fcb0 --- /dev/null +++ b/src/aiida_quantumespresso/utils/decorators.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +"""Decorators for several purposes.""" + + +def remove_none_overrides(func): + """Remove namespaces of the returned builder of a `get_builder*` method.""" + + def recursively_remove_nones(item): + """Recursively remove keys with None values from dictionaries.""" + if isinstance(item, dict): + return {key: recursively_remove_nones(value) for key, value in item.items() if value is not None} + return item + + def remove_keys_from_builder(builder, keys, path=()): + """Recursively remove specified keys from the builder based on a path.""" + if not keys: + return + current_level = keys.pop(0) + if hasattr(builder, current_level): + if keys: + next_attr = getattr(builder, current_level) + remove_keys_from_builder(next_attr, keys, path + (current_level,)) + else: + delattr(builder, current_level) + + def wrapper(*args, **kwargs): + """Wrap the function.""" + if 'overrides' in kwargs and kwargs['overrides'] is not None: + original_overrides = kwargs['overrides'] + + # Identify paths to keys with None values to be removed + paths_to_remove = [] + + def find_paths(item, path=()): + """Find the paths to remove.""" + if isinstance(item, dict): + for key, value in item.items(): + if value is None: + paths_to_remove.append(path + (key,)) + else: + find_paths(value, path + (key,)) + + find_paths(original_overrides) + + # Recursively remove keys with None values from overrides + cleaned_overrides = recursively_remove_nones(original_overrides) + kwargs['overrides'] = cleaned_overrides + + # Call the original function to get the builder + builder = func(*args, **kwargs) + + # Remove specified keys from the builder + for path in paths_to_remove: + remove_keys_from_builder(builder, list(path)) + + return builder + + return func(*args, **kwargs) + + return wrapper diff --git a/src/aiida_quantumespresso/workflows/pw/bands.py b/src/aiida_quantumespresso/workflows/pw/bands.py index d4d0a32e..41d2f267 100644 --- a/src/aiida_quantumespresso/workflows/pw/bands.py +++ b/src/aiida_quantumespresso/workflows/pw/bands.py @@ -5,6 +5,7 @@ from aiida.engine import ToContext, WorkChain, if_ from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis +from aiida_quantumespresso.utils.decorators import remove_none_overrides from aiida_quantumespresso.utils.mapping import prepare_process_inputs from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain @@ -120,6 +121,7 @@ def get_protocol_filepath(cls): return files(pw_protocols) / 'bands.yaml' @classmethod + @remove_none_overrides def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=None, options=None, **kwargs): """Return a builder prepopulated with inputs selected according to the chosen protocol. diff --git a/src/aiida_quantumespresso/workflows/pw/base.py b/src/aiida_quantumespresso/workflows/pw/base.py index 4a3ab038..66121fe0 100644 --- a/src/aiida_quantumespresso/workflows/pw/base.py +++ b/src/aiida_quantumespresso/workflows/pw/base.py @@ -8,6 +8,7 @@ from aiida_quantumespresso.calculations.functions.create_kpoints_from_distance import create_kpoints_from_distance from aiida_quantumespresso.common.types import ElectronicType, RestartType, SpinType +from aiida_quantumespresso.utils.decorators import remove_none_overrides from aiida_quantumespresso.utils.defaults.calculation import pw as qe_defaults from ..protocols.utils import ProtocolMixin @@ -103,6 +104,7 @@ def get_protocol_filepath(cls): return files(pw_protocols) / 'base.yaml' @classmethod + @remove_none_overrides def get_builder_from_protocol( cls, code, diff --git a/tests/workflows/protocols/pw/test_bands.py b/tests/workflows/protocols/pw/test_bands.py index 9ffbf0c6..2ab7d77f 100644 --- a/tests/workflows/protocols/pw/test_bands.py +++ b/tests/workflows/protocols/pw/test_bands.py @@ -103,3 +103,29 @@ def test_options(fixture_code, generate_structure): builder.bands.pw.metadata, # pylint: disable=no-member ): assert subspace['options']['queue_name'] == queue_name, subspace + + +def test_pop_none_overrides(fixture_code, generate_structure): + """Test popping `None` input overrides specified in ``get_builder_from_protocol()`` method.""" + code = fixture_code('quantumespresso.pw') + structure = generate_structure() + + overrides = {'relax': {'base_final_scf': None}} + builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides) + + assert 'base_final_scf' not in builder['relax'] # pylint: disable=no-member + + overrides = {'relax': None} + builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides) + + assert 'relax' not in builder # pylint: disable=no-member + + overrides = {'relax': {'base': {'pw': {'parameters': {'SYSTEM': {'ecutwfc': None}}}}}} + builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides) + + assert 'ecutwfc' in builder['relax']['base']['pw']['parameters']['SYSTEM'] # pylint: disable=no-member + + overrides = {'relax': {'base': {'pw': {'parameters': None}}}} + builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides) + + assert 'parameters' not in builder['relax']['base']['pw'] # pylint: disable=no-member diff --git a/tests/workflows/protocols/pw/test_base.py b/tests/workflows/protocols/pw/test_base.py index d4b398c6..2376ee5b 100644 --- a/tests/workflows/protocols/pw/test_base.py +++ b/tests/workflows/protocols/pw/test_base.py @@ -241,3 +241,14 @@ def test_options(fixture_code, generate_structure): assert metadata['options']['queue_name'] == queue_name assert metadata['options']['withmpi'] == withmpi + + +def test_pop_none_overrides(fixture_code, generate_structure): + """Test popping `None` input overrides specified in ``get_builder_from_protocol()`` method.""" + code = fixture_code('quantumespresso.pw') + structure = generate_structure() + + overrides = {'kpoints_distance': None} + builder = PwBaseWorkChain.get_builder_from_protocol(code, structure, overrides=overrides) + + assert 'kpoints_distance' not in builder # pylint: disable=no-member