Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: pop None inputs specified in overrides #1022

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/aiida_quantumespresso/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
"""Decorators for several purposes."""


def remove_none_overrides(func):
"""Remove namespaces of the returned builder of a `get_builder*` method."""

def recursively_remove_nones(item):
"""Recursively remove keys with None values from dictionaries."""
if isinstance(item, dict):
return {key: recursively_remove_nones(value) for key, value in item.items() if value is not None}
return item

def remove_keys_from_builder(builder, keys, path=()):
"""Recursively remove specified keys from the builder based on a path."""
if not keys:
return
current_level = keys.pop(0)
if hasattr(builder, current_level):
if keys:
next_attr = getattr(builder, current_level)
remove_keys_from_builder(next_attr, keys, path + (current_level,))
else:
delattr(builder, current_level)

def wrapper(*args, **kwargs):
"""Wrap the function."""
if 'overrides' in kwargs and kwargs['overrides'] is not None:
original_overrides = kwargs['overrides']

# Identify paths to keys with None values to be removed
paths_to_remove = []

def find_paths(item, path=()):
"""Find the paths to remove."""
if isinstance(item, dict):
for key, value in item.items():
if value is None:
paths_to_remove.append(path + (key,))
else:
find_paths(value, path + (key,))

find_paths(original_overrides)

# Recursively remove keys with None values from overrides
cleaned_overrides = recursively_remove_nones(original_overrides)
kwargs['overrides'] = cleaned_overrides

# Call the original function to get the builder
builder = func(*args, **kwargs)

# Remove specified keys from the builder
for path in paths_to_remove:
remove_keys_from_builder(builder, list(path))

return builder

return func(*args, **kwargs)

return wrapper
2 changes: 2 additions & 0 deletions src/aiida_quantumespresso/workflows/pw/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from aiida.engine import ToContext, WorkChain, if_

from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis
from aiida_quantumespresso.utils.decorators import remove_none_overrides
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
Expand Down Expand Up @@ -120,6 +121,7 @@ def get_protocol_filepath(cls):
return files(pw_protocols) / 'bands.yaml'

@classmethod
@remove_none_overrides
def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=None, options=None, **kwargs):
"""Return a builder prepopulated with inputs selected according to the chosen protocol.

Expand Down
2 changes: 2 additions & 0 deletions src/aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from aiida_quantumespresso.calculations.functions.create_kpoints_from_distance import create_kpoints_from_distance
from aiida_quantumespresso.common.types import ElectronicType, RestartType, SpinType
from aiida_quantumespresso.utils.decorators import remove_none_overrides
from aiida_quantumespresso.utils.defaults.calculation import pw as qe_defaults

from ..protocols.utils import ProtocolMixin
Expand Down Expand Up @@ -103,6 +104,7 @@ def get_protocol_filepath(cls):
return files(pw_protocols) / 'base.yaml'

@classmethod
@remove_none_overrides
def get_builder_from_protocol(
cls,
code,
Expand Down
26 changes: 26 additions & 0 deletions tests/workflows/protocols/pw/test_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,29 @@ def test_options(fixture_code, generate_structure):
builder.bands.pw.metadata, # pylint: disable=no-member
):
assert subspace['options']['queue_name'] == queue_name, subspace


def test_pop_none_overrides(fixture_code, generate_structure):
"""Test popping `None` input overrides specified in ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()

overrides = {'relax': {'base_final_scf': None}}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'base_final_scf' not in builder['relax'] # pylint: disable=no-member

overrides = {'relax': None}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'relax' not in builder # pylint: disable=no-member

overrides = {'relax': {'base': {'pw': {'parameters': {'SYSTEM': {'ecutwfc': None}}}}}}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'ecutwfc' in builder['relax']['base']['pw']['parameters']['SYSTEM'] # pylint: disable=no-member

overrides = {'relax': {'base': {'pw': {'parameters': None}}}}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'parameters' not in builder['relax']['base']['pw'] # pylint: disable=no-member
11 changes: 11 additions & 0 deletions tests/workflows/protocols/pw/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,14 @@ def test_options(fixture_code, generate_structure):

assert metadata['options']['queue_name'] == queue_name
assert metadata['options']['withmpi'] == withmpi


def test_pop_none_overrides(fixture_code, generate_structure):
"""Test popping `None` input overrides specified in ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()

overrides = {'kpoints_distance': None}
builder = PwBaseWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'kpoints_distance' not in builder # pylint: disable=no-member
Loading