Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More generic configuration filter #244

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e9b768d
FilterOutliers overhaul
Tetracarbonylnickel Dec 11, 2023
3110a6b
Filter Node becomes childclass of ConfigurationSelection
Tetracarbonylnickel Dec 11, 2023
90fe016
introduced test for filter selection
Tetracarbonylnickel Dec 11, 2023
66c6251
Merge branch 'main' into more-generic-strucktures-filter
Tetracarbonylnickel Dec 11, 2023
813a0c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2023
7f278a3
_get_plot() fix
Tetracarbonylnickel Dec 12, 2023
8309a89
Merge branch 'more-generic-strucktures-filter' of https://github.com/…
Tetracarbonylnickel Dec 12, 2023
113989a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2023
21c57b7
Ragged values fix
Tetracarbonylnickel Dec 14, 2023
401177a
node name change
Tetracarbonylnickel Dec 14, 2023
7c11a67
Merge branch 'more-generic-strucktures-filter' of https://github.com/…
Tetracarbonylnickel Dec 14, 2023
2f1a994
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 14, 2023
2168bd3
fexed integration test
Tetracarbonylnickel Dec 14, 2023
97e88fe
Merge branch 'more-generic-strucktures-filter' of https://github.com/…
Tetracarbonylnickel Dec 14, 2023
4404c6a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 14, 2023
4b30b61
Merge branch 'main' into more-generic-strucktures-filter
Tetracarbonylnickel Dec 20, 2023
1be8116
Merge branch 'main' into more-generic-strucktures-filter
Tetracarbonylnickel Dec 22, 2023
770dca3
Merge branch 'main' into more-generic-strucktures-filter
Tetracarbonylnickel Apr 18, 2024
8832668
PropertyFilter fix
Tetracarbonylnickel Apr 19, 2024
02bdaf7
ThresholdSelection fix
Tetracarbonylnickel Apr 19, 2024
ae4800b
node fix
Tetracarbonylnickel Apr 19, 2024
36b45e0
test fix
Tetracarbonylnickel Apr 19, 2024
8ccb5b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
ba19cb1
Merge branch 'main' into more-generic-strucktures-filter
Tetracarbonylnickel Apr 19, 2024
51accca
cleanup
Tetracarbonylnickel Apr 22, 2024
f443082
Merge remote-tracking branch 'origin/main' into more-generic-strucktu…
Tetracarbonylnickel Apr 22, 2024
8d3332d
Merge branch 'more-generic-strucktures-filter' of https://github.com/…
Tetracarbonylnickel Apr 22, 2024
35c35d5
fix integration test
Tetracarbonylnickel Apr 22, 2024
df77d79
Merge branch 'main' into more-generic-strucktures-filter
Tetracarbonylnickel Jul 1, 2024
75d00b5
fix filter test
Tetracarbonylnickel Jul 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ipsuite/configuration_selection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Configuration Selection Nodes."""

from ipsuite.configuration_selection.base import ConfigurationSelection
from ipsuite.configuration_selection.filter import FilterOutlier
from ipsuite.configuration_selection.filter import PropertyFilter
from ipsuite.configuration_selection.index import IndexSelection
from ipsuite.configuration_selection.kernel import KernelSelection
from ipsuite.configuration_selection.random import RandomSelection
Expand All @@ -21,5 +21,5 @@
"IndexSelection",
"ThresholdSelection",
"SplitSelection",
"FilterOutlier",
"PropertyFilter",
]
115 changes: 78 additions & 37 deletions ipsuite/configuration_selection/filter.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,111 @@
import typing as t

import ase
import matplotlib.pyplot as plt
import numpy as np
import zntrack

from ipsuite import base
from ipsuite.configuration_selection import ConfigurationSelection


class FilterOutlier(base.ProcessAtoms):
"""Remove outliers from the data based on a given property.
def direct_cutoff(values, threshold, cutoffs):
# Filtering the direct cutoff values
if cutoffs is None:
raise ValueError(
"cutoffs have to be specified for using the direct cutoff filter."
)
return (cutoffs[0], cutoffs[1])


def cutoff_around_mean(values, threshold, cutoffs):
# Filtering in multiples of the standard deviation around the mean.
mean = np.mean(values)
std = np.std(values)

upper_cutoff = mean + threshold * std
lower_cutoff = mean - threshold * std
return (lower_cutoff, upper_cutoff)


CUTOFF = {"direct": direct_cutoff, "around_mean": cutoff_around_mean}


class PropertyFilter(ConfigurationSelection):
"""Filter structures from the dataset based on a given property.

Tetracarbonylnickel marked this conversation as resolved.
Show resolved Hide resolved
Attributes
----------
key : str, default="energy"
The property to filter on.
threshold : float, default=3
The threshold for filtering in units of standard deviations.
cutoff_type : {"direct", "around_mean"}, default="around_mean"
Defines the cutoff type.
direction : {"above", "below", "both"}, default="both"
The direction to filter in.
threshold : float, default=3
The threshold for filtering in units of standard deviations.
cutoffs : list(float), default=None
Lower and upper cutoff.
"""

key: str = zntrack.params("energy")
threshold: float = zntrack.params(3)
cutoff_type: t.Literal["direct", "around_mean"] = zntrack.params("around_mean")
direction: t.Literal["above", "below", "both"] = zntrack.params("both")
threshold: float = zntrack.params(3)
cutoffs: t.Union[t.List[float], None] = zntrack.params(None)

def select_atoms(self, atoms_lst: t.List[ase.Atoms]) -> t.List[int]:
values = [atoms.calc.results[self.key] for atoms in atoms_lst]

filtered_indices: list = zntrack.outs()
histogram: str = zntrack.outs_path(zntrack.nwd / "histogram.png")
# get maximal atomic value per struckture
if isinstance(values[0], np.ndarray):
if values[0].ndim == 2:
# calculates the maximal magnetude of atomic cartesian property
values = [
np.max(np.linalg.norm(value, axis=1), axis=0) for value in values
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unify with plotting

]
elif values[0].ndim == 1:
# calculates the maximal atomic property
values = [np.max(value, axis=0) for value in values]

def run(self):
values = [x.calc.results[self.key] for x in self.data]
mean = np.mean(values)
std = np.std(values)
lower_limit, upper_limit = CUTOFF[self.cutoff_type](
values,
self.threshold,
self.cutoffs,
)

if self.direction == "above":
self.filtered_indices = [
i for i, x in enumerate(values) if x > mean + self.threshold * std
]
selection = [i for i, x in enumerate(values) if x > upper_limit]
elif self.direction == "below":
self.filtered_indices = [
i for i, x in enumerate(values) if x < mean - self.threshold * std
]
selection = [i for i, x in enumerate(values) if x < lower_limit]
else:
self.filtered_indices = [
i
for i, x in enumerate(values)
if x > mean + self.threshold * std or x < mean - self.threshold * std
selection = [
i for i, x in enumerate(values) if x > lower_limit and x < upper_limit
]

return selection

def _get_plot(self, atoms_lst: t.List[ase.Atoms], indices: t.List[int]):
values = [atoms.calc.results[self.key] for atoms in atoms_lst]

# get maximal atomic value per struckture
if isinstance(values[0], np.ndarray):
if values[0].ndim == 2:
# calculates the maximal magnetude of atomic cartesian property
values = [
np.max(np.linalg.norm(value, axis=1), axis=0) for value in values
]
elif values[0].ndim == 1:
# calculates the maximal atomic property
values = [np.max(value, axis=0) for value in values]

fig, ax = plt.subplots(3, figsize=(10, 10))
ax[0].hist(values, bins=100)
ax[0].set_title("All")
ax[1].hist(
[values[i] for i in range(len(values)) if i not in self.filtered_indices],
ax[1].hist([values[i] for i in indices], bins=100)
ax[1].set_title("Selected")
ax[2].hist(
[values[i] for i in range(len(values)) if i not in indices],
bins=100,
)
ax[1].set_title("Filtered")
ax[2].hist([values[i] for i in self.filtered_indices], bins=100)
ax[2].set_title("Excluded")
fig.savefig(self.histogram, bbox_inches="tight")

@property
def atoms(self):
return [
self.data[i] for i in range(len(self.data)) if i not in self.filtered_indices
]

@property
def excluded_atoms(self):
return [self.data[i] for i in self.filtered_indices]
fig.savefig(self.img_selection, bbox_inches="tight")
2 changes: 1 addition & 1 deletion ipsuite/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class _Nodes:
)
UniformTemporalSelection = "ipsuite.configuration_selection.UniformTemporalSelection"
ThresholdSelection = "ipsuite.configuration_selection.ThresholdSelection"
FilterOutlier = "ipsuite.configuration_selection.FilterOutlier"
PropertyFilter = "ipsuite.configuration_selection.PropertyFilter"
BatchKernelSelection = "ipsuite.models.apax.BatchKernelSelection"

# Configuration Comparison
Expand Down
8 changes: 6 additions & 2 deletions tests/integration/configuration_selection/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,12 @@ def test_exclude_configurations_list(proj_path, traj_file):
def test_filter_outlier(proj_path, traj_file):
with ips.Project() as project:
data = ips.AddData(file=traj_file)
filtered_data = ips.configuration_selection.FilterOutlier(
data=data.atoms, key="energy", threshold=1, direction="both"
filtered_data = ips.configuration_selection.PropertyFilter(
data=data.atoms,
key="energy",
cutoff_type="around_mean",
threshold=1,
direction="both",
)

project.run()
Expand Down
36 changes: 36 additions & 0 deletions tests/unit_tests/configuration_selection/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np
import pytest

from ipsuite.configuration_selection import PropertyFilter


@pytest.mark.parametrize(
"key, cutoff_type, direction, cutoffs",
[
("forces", "direct", "both", [7, 13]),
Tetracarbonylnickel marked this conversation as resolved.
Show resolved Hide resolved
("forces", "direct", "both", None),
("forces", "around_mean", "both", None),
],
)
def test_get_selected_atoms(atoms_list, key, cutoff_type, direction, cutoffs):
for idx, atoms in enumerate(atoms_list):
atoms.calc.results[key] = np.array([[idx, 0, 0], [0, 0, 0]])

filter = PropertyFilter(
key=key,
cutoff_type=cutoff_type,
direction=direction,
data=None,
cutoffs=cutoffs,
threshold=0.4,
)

if "direct" in cutoff_type and cutoffs is None:
with pytest.raises(ValueError):
selected_atoms = filter.select_atoms(atoms_list)
else:
test_selection = [8, 9, 10, 11, 12]
selected_atoms = filter.select_atoms(atoms_list)
assert isinstance(selected_atoms, list)
assert len(set(selected_atoms)) == 5
assert selected_atoms == test_selection
Loading