diff --git a/CHANGELOG.md b/CHANGELOG.md index b75c6722c..597318b25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ _ `_optional` subpackage for managing optional dependencies - `register_hooks` utility enabling user-defined augmentation of arbitrary callables - `transform` methods of `SearchSpace`, `SubspaceDiscrete` and `SubspaceContinuous` now take additional `allow_missing` and `allow_extra` keyword arguments +- Utilities for permutation and dependency data augmentation ### Changed - Passing an `Objective` to `Campaign` is now optional diff --git a/baybe/constraints/base.py b/baybe/constraints/base.py index 3509c27f0..00797d085 100644 --- a/baybe/constraints/base.py +++ b/baybe/constraints/base.py @@ -36,6 +36,10 @@ class Constraint(ABC, SerialMixin): eval_during_modeling: ClassVar[bool] """Class variable encoding whether the condition is evaluated during modeling.""" + eval_during_augmentation: ClassVar[bool] = False + """Class variable encoding whether the constraint could be considered during data + augmentation.""" + numerical_only: ClassVar[bool] = False """Class variable encoding whether the constraint is valid only for numerical parameters.""" diff --git a/baybe/constraints/discrete.py b/baybe/constraints/discrete.py index ee8cf9d0e..468df14ba 100644 --- a/baybe/constraints/discrete.py +++ b/baybe/constraints/discrete.py @@ -133,6 +133,10 @@ class DiscreteDependenciesConstraint(DiscreteConstraint): a single constraint. """ + # class variables + eval_during_augmentation: ClassVar[bool] = True + # See base class + # object variables conditions: list[Condition] = field() """The list of individual conditions.""" @@ -220,6 +224,10 @@ class DiscretePermutationInvarianceConstraint(DiscreteConstraint): evaluated during modeling to make use of the invariance. """ + # class variables + eval_during_augmentation: ClassVar[bool] = True + # See base class + # object variables dependencies: DiscreteDependenciesConstraint | None = field(default=None) """Dependencies connected with the invariant parameters.""" diff --git a/baybe/exceptions.py b/baybe/exceptions.py index d92e20569..94d0324c0 100644 --- a/baybe/exceptions.py +++ b/baybe/exceptions.py @@ -9,6 +9,14 @@ class UnusedObjectWarning(UserWarning): """ +class NoSearchspaceMatchWarning(UserWarning): + """The provided input has no match in the searchspace.""" + + +class TooManySearchspaceMatchesWarning(UserWarning): + """The provided input has multiple matches in the searchspace.""" + + ##### Exceptions ##### class NotEnoughPointsLeftError(Exception): """ diff --git a/baybe/searchspace/continuous.py b/baybe/searchspace/continuous.py index ab9c2b715..aad40dfb8 100644 --- a/baybe/searchspace/continuous.py +++ b/baybe/searchspace/continuous.py @@ -458,6 +458,19 @@ def full_factorial(self) -> pd.DataFrame: return pd.DataFrame(index=index).reset_index() + def get_parameters_by_name( + self, names: Sequence[str] + ) -> tuple[NumericalContinuousParameter, ...]: + """Return parameters with the specified names. + + Args: + names: Sequence of names. + + Returns: + The named parameters. + """ + return tuple(p for p in self.parameters if p.name in names) + # Register deserialization hook converter.register_structure_hook(SubspaceContinuous, select_constructor_hook) diff --git a/baybe/searchspace/core.py b/baybe/searchspace/core.py index 17f1f49f9..868339477 100644 --- a/baybe/searchspace/core.py +++ b/baybe/searchspace/core.py @@ -380,6 +380,24 @@ def transform( return comp_rep + @property + def constraints_augmentable(self) -> tuple[Constraint, ...]: + """The searchspace constraints that can be considered during augmentation.""" + return tuple(c for c in self.constraints if c.eval_during_augmentation) + + def get_parameters_by_name(self, names: Sequence[str]) -> tuple[Parameter, ...]: + """Return parameters with the specified names. + + Args: + names: Sequence of names. + + Returns: + The named parameters. + """ + return self.discrete.get_parameters_by_name( + names + ) + self.continuous.get_parameters_by_name(names) + def validate_searchspace_from_config(specs: dict, _) -> None: """Validate the search space specifications while skipping costly creation steps.""" diff --git a/baybe/searchspace/discrete.py b/baybe/searchspace/discrete.py index 9d41f2f14..11e179e26 100644 --- a/baybe/searchspace/discrete.py +++ b/baybe/searchspace/discrete.py @@ -713,6 +713,19 @@ def transform( except AttributeError: return comp_rep + def get_parameters_by_name( + self, names: Sequence[str] + ) -> tuple[DiscreteParameter, ...]: + """Return parameters with the specified names. + + Args: + names: Sequence of names. + + Returns: + The named parameters. + """ + return tuple(p for p in self.parameters if p.name in names) + def _apply_constraint_filter( df: pd.DataFrame, constraints: Collection[DiscreteConstraint] diff --git a/baybe/utils/augmentation.py b/baybe/utils/augmentation.py new file mode 100644 index 000000000..b4b6c5e6e --- /dev/null +++ b/baybe/utils/augmentation.py @@ -0,0 +1,220 @@ +"""Utilities related to data augmentation.""" + +from collections.abc import Sequence +from itertools import permutations, product + +import pandas as pd + + +def _row_in_df(row: pd.Series | pd.DataFrame, df: pd.DataFrame) -> bool: + """Check whether a row is fully contained in a dataframe. + + Args: + row: The row to be checked. + df: The dataframe to be checked. + + Returns: + Boolean result. + + Raises: + ValueError: If ``row`` is a dataframe that contains more than one row. + """ + if isinstance(row, pd.DataFrame): + if len(row) != 1: + raise ValueError( + f"{_row_in_df.__name__} can only be called with pd.Series or " + f"pd.DataFrames that have exactly one row." + ) + row = row.iloc[0] + + row = row.reindex(df.columns) + return (df == row).all(axis=1).any() + + +def df_apply_permutation_augmentation( + df: pd.DataFrame, + columns: Sequence[Sequence[str]], +) -> pd.DataFrame: + """Augment a dataframe if permutation invariant columns are present. + + * Original + + +----+----+----+----+ + | A1 | A2 | B1 | B2 | + +====+====+====+====+ + | a | b | x | y | + +----+----+----+----+ + | b | a | x | z | + +----+----+----+----+ + + * Result with ``columns = [["A1"], ["A2"]]`` + + +----+----+----+----+ + | A1 | A2 | B1 | B2 | + +====+====+====+====+ + | a | b | x | y | + +----+----+----+----+ + | b | a | x | z | + +----+----+----+----+ + | b | a | x | y | + +----+----+----+----+ + | a | b | x | z | + +----+----+----+----+ + + * Result with ``columns = [["A1", "B1"], ["A2", "B2"]]`` + + +----+----+----+----+ + | A1 | A2 | B1 | B2 | + +====+====+====+====+ + | a | b | x | y | + +----+----+----+----+ + | b | a | x | z | + +----+----+----+----+ + | b | a | y | x | + +----+----+----+----+ + | a | b | z | x | + +----+----+----+----+ + + Args: + df: The dataframe that should be augmented. + columns: Sequences of permutation invariant columns. The n'th column in each + sequence will be permuted together with each n'th column in the other + sequences. + + Returns: + The augmented dataframe containing the original one. + + Raises: + ValueError: If ``dependents`` has length incompatible with ``columns``. + ValueError: If entries in ``dependents`` are not of same length. + """ + # Validation + if len(columns) < 2: + raise ValueError( + "When augmenting permutation invariance, at least two column sequences " + "must be given." + ) + if len({len(seq) for seq in columns}) != 1 or len(columns[0]) < 1: + raise ValueError( + "Permutation augmentation can only work if the amount of columns un each " + "sequence is the same and the sequences are not empty." + ) + + # Augmentation Loop + new_rows: list[pd.DataFrame] = [] + idx_permutation = list(permutations(range(len(columns)))) + for _, row in df.iterrows(): + to_add = [] + for _, perm in enumerate(idx_permutation): + new_row = row.copy() + + # Permute columns + for deps in map(list, zip(*columns)): + new_row[deps] = row[[deps[k] for k in perm]] + + # Check whether the new row is an existing permutation + if not _row_in_df(new_row, df): + to_add.append(new_row) + + new_rows.append(pd.DataFrame(to_add)) + augmented_df = pd.concat([df] + new_rows) + + return augmented_df + + +def df_apply_dependency_augmentation( + df: pd.DataFrame, + causing: tuple[str, Sequence], + affected: Sequence[tuple[str, Sequence]], +) -> pd.DataFrame: + """Augment a dataframe if dependency invariant columns are present. + + This works with the concept of column-values pairs for causing and affected column. + Any row present where the specified causing column has one of the provided values + will trigger an augmentation on the affected columns. The latter are augmented by + going through all their invariant values and adding respective new rows. + + * Original + + +---+---+---+---+ + | A | B | C | D | + +===+===+===+===+ + | 0 | 2 | 5 | y | + +---+---+---+---+ + | 1 | 3 | 5 | z | + +---+---+---+---+ + + * Result with ``causing = ("A", [0])``, ``affected = [("B", [2,3,4])]`` + + +---+---+---+---+ + | A | B | C | D | + +===+===+===+===+ + | 0 | 2 | 5 | y | + +---+---+---+---+ + | 1 | 3 | 5 | z | + +---+---+---+---+ + | 0 | 3 | 5 | y | + +---+---+---+---+ + | 0 | 4 | 5 | y | + +---+---+---+---+ + + * Result with ``causing = ("A", [0, 1])``, ``affected = [("B", [2,3])]`` + + +---+---+---+---+ + | A | B | C | D | + +===+===+===+===+ + | 0 | 2 | 5 | y | + +---+---+---+---+ + | 1 | 3 | 5 | z | + +---+---+---+---+ + | 0 | 3 | 5 | y | + +---+---+---+---+ + | 1 | 2 | 5 | z | + +---+---+---+---+ + + * Result with ``causing = ("A", [0])``, + ``affected = [("B", [2,3]), ("C", [5, 6])]`` + + +---+---+---+---+ + | A | B | C | D | + +===+===+===+===+ + | 0 | 2 | 5 | y | + +---+---+---+---+ + | 1 | 3 | 5 | z | + +---+---+---+---+ + | 0 | 3 | 5 | y | + +---+---+---+---+ + | 0 | 2 | 6 | y | + +---+---+---+---+ + | 0 | 3 | 6 | y | + +---+---+---+---+ + + Args: + df: The dataframe that should be augmented. + causing: Causing column name and its causing values. + affected: Affected columns and their invariant values. + + Returns: + The augmented dataframe containing the original one. + """ + new_rows: list[pd.DataFrame] = [] + col_causing, vals_causing = causing + df_filtered = df.loc[df[col_causing].isin(vals_causing), :] + affected_cols, affected_inv_vals = zip(*affected) + affected_inv_vals_combinations = list(product(*affected_inv_vals)) + + # Iterate through all rows that have a causing value in the respective column. + for _, r in df_filtered.iterrows(): + # Create augmented rows + to_add = [ + pd.Series({**r.to_dict(), **dict(zip(affected_cols, values))}) + for values in affected_inv_vals_combinations + ] + + # Do not include rows that were present in the original + to_add = [r2 for r2 in to_add if not _row_in_df(r2, df_filtered)] + new_rows.append(pd.DataFrame(to_add)) + + augmented_df = pd.concat([df] + new_rows) + + return augmented_df diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 5bc09c270..b8d4ed02e 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import warnings from collections.abc import Iterable, Iterator, Sequence from typing import ( TYPE_CHECKING, @@ -13,6 +14,7 @@ import numpy as np import pandas as pd +from baybe.exceptions import NoSearchspaceMatchWarning, TooManySearchspaceMatchesWarning from baybe.targets.enum import TargetMode from baybe.utils.numerical import DTypeFloatNumpy @@ -417,17 +419,17 @@ def fuzzy_row_match( # We expect exactly one match. If that's not the case, print a warning. inds_found = left_df.index[match].to_list() if len(inds_found) == 0 and len(num_cols) > 0: - _logger.warning( - "Input row with index %s could not be matched to the search space. " + warnings.warn( + f"Input row with index {ind} could not be matched to the search space. " "This could indicate that something went wrong.", - ind, + NoSearchspaceMatchWarning, ) elif len(inds_found) > 1: - _logger.warning( - "Input row with index %s has multiple matches with " + warnings.warn( + f"Input row with index {ind} has multiple matches with " "the search space. This could indicate that something went wrong. " "Matching only first occurrence.", - ind, + TooManySearchspaceMatchesWarning, ) inds_matched.append(inds_found[0]) else: diff --git a/tests/test_input_output.py b/tests/test_input_output.py index cc1060795..4ec96184d 100644 --- a/tests/test_input_output.py +++ b/tests/test_input_output.py @@ -1,13 +1,18 @@ """Tests for basic input-output and iterative loop.""" +import warnings + import numpy as np +import pandas as pd import pytest +from baybe.constraints import DiscreteNoLabelDuplicatesConstraint +from baybe.exceptions import NoSearchspaceMatchWarning +from baybe.utils.augmentation import ( + df_apply_dependency_augmentation, + df_apply_permutation_augmentation, +) from baybe.utils.dataframe import add_fake_results -# List of tests that are expected to fail (still missing implementation etc) -param_xfails = [] -target_xfails = [] - @pytest.mark.parametrize( "bad_val", @@ -16,9 +21,6 @@ ) def test_bad_parameter_input_value(campaign, good_reference_values, bad_val, request): """Test attempting to read in an invalid parameter value.""" - if request.node.callspec.id in param_xfails: - pytest.xfail() - rec = campaign.recommend(batch_size=3) add_fake_results( rec, @@ -27,7 +29,11 @@ def test_bad_parameter_input_value(campaign, good_reference_values, bad_val, req ) # Add an invalid value - rec.Num_disc_1.iloc[0] = bad_val + with warnings.catch_warnings(): + # Ignore warning about incompatible data type assignment + warnings.simplefilter("ignore", FutureWarning) + rec.iloc[0, rec.columns.get_loc("Num_disc_1")] = bad_val + with pytest.raises((ValueError, TypeError)): campaign.add_measurements(rec) @@ -39,9 +45,6 @@ def test_bad_parameter_input_value(campaign, good_reference_values, bad_val, req ) def test_bad_target_input_value(campaign, good_reference_values, bad_val, request): """Test attempting to read in an invalid target value.""" - if request.node.callspec.id in target_xfails: - pytest.xfail() - rec = campaign.recommend(batch_size=3) add_fake_results( rec, @@ -50,6 +53,131 @@ def test_bad_target_input_value(campaign, good_reference_values, bad_val, reques ) # Add an invalid value - rec.Target_max.iloc[0] = bad_val + with warnings.catch_warnings(): + # Ignore warning about incompatible data type assignment + warnings.simplefilter("ignore", FutureWarning) + rec.iloc[0, rec.columns.get_loc("Target_max")] = bad_val + with pytest.raises((ValueError, TypeError)): campaign.add_measurements(rec) + + +# Reused parameter names for the mixture mock example +_mixture_columns = [ + "Solvent_1", + "Solvent_2", + "Solvent_3", + "Fraction_1", + "Fraction_2", + "Fraction_3", +] + + +@pytest.mark.parametrize("n_grid_points", [5]) +@pytest.mark.parametrize( + "entry", + [ + pd.DataFrame.from_records( + [["THF", "Water", "DMF", 0.0, 25.0, 75.0]], columns=_mixture_columns + ), + ], +) +@pytest.mark.parametrize("parameter_names", [_mixture_columns]) +@pytest.mark.parametrize( + "constraint_names", [["Constraint_7", "Constraint_11", "Constraint_12"]] +) +def test_permutation_invariant_input(campaign, entry): + """Test whether permutation invariant measurements can be added.""" + add_fake_results(entry, campaign) + + # Create augmented combinations + entries = df_apply_permutation_augmentation( + entry, + columns=["Solvent_1", "Solvent_2", "Solvent_3"], + dependents=["Fraction_1", "Fraction_2", "Fraction_3"], + ) + + for _, row in entries.iterrows(): + # Reset searchspace metadata + campaign.searchspace.discrete.metadata["was_measured"] = False + + # Assert that not NoSearchspaceMatchWarning is thrown + with warnings.catch_warnings(): + print(row.to_frame().T) + warnings.simplefilter("error", category=NoSearchspaceMatchWarning) + campaign.add_measurements(pd.DataFrame([row])) + + # Assert exactly one searchspace entry has been marked + num_nonzero = campaign.searchspace.discrete.metadata["was_measured"].sum() + assert num_nonzero == 1, ( + "Measurement ingestion was successful, but did not correctly update the " + f"searchspace metadata. Number of non-zero entries: {num_nonzero} " + f"(expected 1)" + ) + + +@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) +@pytest.mark.parametrize( + "entry", + [ + pd.DataFrame.from_records( + [["THF", "Water", "DMF", 0.0, 25.0, 75.0]], + columns=_mixture_columns, + ), + pd.DataFrame.from_records( + [["THF", "Water", "DMF", 0.0, 0.0, 50.0]], + columns=_mixture_columns, + ), + ], + ids=["single_degen", "double_degen"], +) +@pytest.mark.parametrize("parameter_names", [_mixture_columns]) +@pytest.mark.parametrize( + "constraint_names", [["Constraint_7", "Constraint_11", "Constraint_12"]] +) +def test_dependency_invariant_input(campaign, entry): + """Test whether dependency invariant measurements can be added.""" + # Get an entry from the searchspace + add_fake_results(entry, campaign) + sol_vals = campaign.searchspace.get_parameters_by_name(["Solvent_1"])[0].values + + # Create augmented combinations + entries = df_apply_dependency_augmentation( + entry, causing=("Fraction_1", [0.0]), affected=[("Solvent_1", sol_vals)] + ) + entries = df_apply_dependency_augmentation( + entries, causing=("Fraction_2", [0.0]), affected=[("Solvent_2", sol_vals)] + ) + entries = df_apply_dependency_augmentation( + entries, causing=("Fraction_3", [0.0]), affected=[("Solvent_3", sol_vals)] + ) + + # Remove falsely created label duplicates + entries.reset_index(drop=True, inplace=True) + for c in campaign.searchspace.discrete.constraints: + if isinstance(c, DiscreteNoLabelDuplicatesConstraint): + entries.drop(index=c.get_invalid(entries), inplace=True) + + # Add nan entries for testing nan input in the invariant parameters + entry_nan = entry.copy() + entry_nan.loc[entry_nan["Fraction_1"] == 0.0, "Solvent_1"] = np.nan + entry_nan.loc[entry_nan["Fraction_2"] == 0.0, "Solvent_2"] = np.nan + entry_nan.loc[entry_nan["Fraction_3"] == 0.0, "Solvent_3"] = np.nan + + for _, row in pd.concat([entries, entry_nan]).iterrows(): + # Reset searchspace metadata + campaign.searchspace.discrete.metadata["was_measured"] = False + + # Assert that not NoSearchspaceMatchWarning is thrown + with warnings.catch_warnings(): + print(row.to_frame().T) + warnings.simplefilter("error", category=NoSearchspaceMatchWarning) + campaign.add_measurements(pd.DataFrame([row])) + + # Assert exactly one searchspace entry has been marked + num_nonzero = campaign.searchspace.discrete.metadata["was_measured"].sum() + assert num_nonzero == 1, ( + "Measurement ingestion was successful, but did not correctly update the " + f"searchspace metadata. Number of non-zero entries: {num_nonzero} " + f"(expected 1)" + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 956a7b51e..a2bc6612e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,6 +6,10 @@ import pytest from pytest import param +from baybe.utils.augmentation import ( + df_apply_dependency_augmentation, + df_apply_permutation_augmentation, +) from baybe.utils.basic import register_hooks from baybe.utils.memory import bytes_to_human_readable from baybe.utils.numerical import closest_element @@ -120,3 +124,243 @@ def test_invalid_register_hooks(target, hook): """Passing inconsistent signatures to `register_hooks` raises an error.""" with pytest.raises(TypeError): register_hooks(target, [hook]) + + +@pytest.mark.parametrize( + ("data", "columns", "data_expected"), + [ + param( # 2 invariant cols and 1 unaffected col + { + "A": [1, 1], + "B": [2, 2], + "C": ["x", "y"], + }, + [["A"], ["B"]], + { + "A": [1, 2, 1, 2], + "B": [2, 1, 2, 1], + "C": ["x", "x", "y", "y"], + }, + id="2inv_1add", + ), + param( # 2 invariant cols with identical values + { + "A": [1, 1], + "B": [2, 2], + }, + [["A"], ["B"]], + { + "A": [1, 1, 2], + "B": [2, 2, 1], + }, + id="2inv+degen", + ), + param( # 2 invariant cols with identical values but different targets + { + "A": [1, 1], + "B": [2, 2], + "T": ["x", "y"], + }, + [["A"], ["B"]], + { + "A": [1, 1, 2, 2], + "B": [2, 2, 1, 1], + "T": ["x", "y", "x", "y"], + }, + id="2inv+degen_target", + ), + param( # 2 invariant cols with identical values but different targets + { + "A": [1, 1], + "B": [2, 2], + "T": ["x", "x"], + }, + [["A"], ["B"]], + { + "A": [1, 2], + "B": [2, 1], + "T": ["x", "x"], + }, + id="2inv+degen_target+degen", + ), + param( # 3 invariant cols + { + "A": [1, 1], + "B": [2, 4], + "C": [3, 5], + "D": ["x", "y"], + }, + [["A"], ["B"], ["C"]], + { + "A": [1, 1, 2, 2, 3, 3, 1, 1, 4, 4, 5, 5], + "B": [2, 3, 1, 3, 2, 1, 4, 5, 1, 5, 1, 4], + "C": [3, 2, 3, 1, 1, 2, 5, 4, 5, 1, 4, 1], + "D": ["x", "x", "x", "x", "x", "x", "y", "y", "y", "y", "y", "y"], + }, + id="3inv_1add", + ), + param( # 2 invariant cols, 2 dependent ones, 2 additional ones + { + "Slot1": ["s1", "s2"], + "Slot2": ["s2", "s4"], + "Frac1": [0.1, 0.6], + "Frac2": [0.9, 0.4], + "Other1": ["A", "B"], + "Other2": ["C", "D"], + }, + [["Slot1", "Frac1"], ["Slot2", "Frac2"]], + { + "Slot1": ["s1", "s2", "s2", "s4"], + "Slot2": ["s2", "s4", "s1", "s2"], + "Frac1": [0.1, 0.6, 0.9, 0.4], + "Frac2": [0.9, 0.4, 0.1, 0.6], + "Other1": ["A", "B", "A", "B"], + "Other2": ["C", "D", "C", "D"], + }, + id="2inv_2dependent_2add", + ), + param( # 2 invariant cols, 2 dependent ones, 2 additional ones + { + "Slot1": ["s1", "s2"], + "Slot2": ["s2", "s4"], + "Frac1": [0.1, 0.6], + "Frac2": [0.9, 0.4], + "Temp1": [10, 20], + "Temp2": [50, 60], + "Other": ["x", "y"], + }, + [["Slot1", "Frac1", "Temp1"], ["Slot2", "Frac2", "Temp2"]], + { + "Slot1": ["s1", "s2", "s2", "s4"], + "Slot2": ["s2", "s4", "s1", "s2"], + "Frac1": [0.1, 0.6, 0.9, 0.4], + "Frac2": [0.9, 0.4, 0.1, 0.6], + "Temp1": [10, 20, 50, 60], + "Temp2": [50, 60, 10, 20], + "Other": ["x", "y", "x", "y"], + }, + id="2inv_4dependent2each_1add", + ), + ], +) +def test_df_permutation_aug(data, columns, data_expected): + """Test permutation invariance data augmentation is done correctly.""" + # Create all needed dataframes + df = pd.DataFrame(data) + df_augmented = df_apply_permutation_augmentation(df, columns) + df_expected = pd.DataFrame(data_expected) + + # Determine equality ignoring row order + are_equal = ( + pd.merge(left=df_augmented, right=df_expected, how="outer", indicator=True)[ + "_merge" + ] + .eq("both") + .all() + ) + + assert ( + are_equal + ), f"\norig:\n{df}\n\naugmented:\n{df_augmented}\n\nexpected:\n{df_expected}" + + +@pytest.mark.parametrize( + ("columns", "msg"), + [ + param([], "at least two column sequences", id="no_seqs"), + param([["A"]], "at least two column sequences", id="just_one_seq"), + param([["A"], ["B", "C"]], "sequence is the same", id="different_lengths"), + param([[], []], "sequence is the same", id="empty_seqs"), + ], +) +def test_df_permutation_aug_invalid(columns, msg): + """Test correct errors for invalid permutation attempts.""" + df = pd.DataFrame({"A": [1, 1], "B": [2, 2], "C": ["x", "y"]}) + with pytest.raises(ValueError, match=msg): + df_apply_permutation_augmentation(df, columns) + + +@pytest.mark.parametrize( + ("data", "causing", "affected", "data_expected"), + [ + param( # 1 causing val, 1 col affected (with 3 values) + { + "A": [0, 1], + "B": [3, 4], + "C": ["x", "y"], + }, + ("A", [0]), + [("B", [3, 4, 5])], + { + "A": [0, 1, 0, 0], + "B": [3, 4, 4, 5], + "C": ["x", "y", "x", "x"], + }, + id="1causing_1affected", + ), + param( # 1 causing val, 2 cols affected (with 2 values each) + { + "A": [0, 1], + "B": [3, 4], + "C": ["x", "y"], + }, + ("A", [0]), + [("B", [3, 4]), ("C", ["x", "y"])], + { + "A": [0, 1, 0, 0, 0], + "B": [3, 4, 4, 3, 4], + "C": ["x", "y", "x", "y", "y"], + }, + id="1causing_2affected", + ), + param( # 2 causing vals, 1 col affected (with 3 values) + { + "A": [0, 1, 2], + "B": [3, 4, 3], + "C": ["x", "y", "z"], + }, + ("A", [0, 1]), + [("B", [3, 4, 5])], + { + "A": [0, 1, 2, 0, 0, 1, 1], + "B": [3, 4, 3, 4, 5, 3, 5], + "C": ["x", "y", "z", "x", "x", "y", "y"], + }, + id="2causing_1affected", + ), + param( # 2 causing vals, 2 cols affected (with 2 values each) + { + "A": [0, 1, 2], + "B": [3, 4, 3], + "C": ["x", "y", "x"], + }, + ("A", [0, 1]), + [("B", [3, 4]), ("C", ["x", "y"])], + { + "A": [0, 1, 2, 0, 0, 0, 1, 1, 1], + "B": [3, 4, 3, 4, 3, 4, 3, 3, 4], + "C": ["x", "y", "x", "x", "y", "y", "y", "x", "x"], + }, + id="2causing_2affected", + ), + ], +) +def test_df_dependency_aug(data, causing, affected, data_expected): + """Test dependency data augmentation is done correctly.""" + # Create all needed dataframes + df = pd.DataFrame(data) + df_augmented = df_apply_dependency_augmentation(df, causing, affected) + df_expected = pd.DataFrame(data_expected) + + # Determine equality ignoring row order + are_equal = ( + pd.merge(left=df_augmented, right=df_expected, how="outer", indicator=True)[ + "_merge" + ] + .eq("both") + .all() + ) + + assert ( + are_equal + ), f"\norig:\n{df}\n\naugmented:\n{df_augmented}\n\nexpected:\n{df_expected}"