Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Data Input With Invariant Parameters #291

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ _ `_optional` subpackage for managing optional dependencies
- `register_hooks` utility enabling user-defined augmentation of arbitrary callables
- `transform` methods of `SearchSpace`, `SubspaceDiscrete` and `SubspaceContinuous`
now take additional `allow_missing` and `allow_extra` keyword arguments
- Utilities for permutation and dependency data augmentation

### Changed
- Passing an `Objective` to `Campaign` is now optional
Expand Down
4 changes: 4 additions & 0 deletions baybe/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class Constraint(ABC, SerialMixin):
eval_during_modeling: ClassVar[bool]
"""Class variable encoding whether the condition is evaluated during modeling."""

eval_during_augmentation: ClassVar[bool] = False
"""Class variable encoding whether the constraint could be considered during data
augmentation."""

numerical_only: ClassVar[bool] = False
"""Class variable encoding whether the constraint is valid only for numerical
parameters."""
Expand Down
8 changes: 8 additions & 0 deletions baybe/constraints/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ class DiscreteDependenciesConstraint(DiscreteConstraint):
a single constraint.
"""

# class variables
eval_during_augmentation: ClassVar[bool] = True
# See base class

# object variables
conditions: list[Condition] = field()
"""The list of individual conditions."""
Expand Down Expand Up @@ -220,6 +224,10 @@ class DiscretePermutationInvarianceConstraint(DiscreteConstraint):
evaluated during modeling to make use of the invariance.
"""

# class variables
eval_during_augmentation: ClassVar[bool] = True
# See base class

# object variables
dependencies: DiscreteDependenciesConstraint | None = field(default=None)
"""Dependencies connected with the invariant parameters."""
Expand Down
8 changes: 8 additions & 0 deletions baybe/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ class UnusedObjectWarning(UserWarning):
"""


class NoSearchspaceMatchWarning(UserWarning):
    """Warning raised when an input row cannot be matched to any searchspace point."""


class TooManySearchspaceMatchesWarning(UserWarning):
    """Warning raised when an input row matches more than one searchspace point."""


##### Exceptions #####
class NotEnoughPointsLeftError(Exception):
"""
Expand Down
13 changes: 13 additions & 0 deletions baybe/searchspace/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,19 @@ def full_factorial(self) -> pd.DataFrame:

return pd.DataFrame(index=index).reset_index()

def get_parameters_by_name(
    self, names: Sequence[str]
) -> tuple[NumericalContinuousParameter, ...]:
    """Return parameters with the specified names.

    Args:
        names: Sequence of names.

    Returns:
        The named parameters.
    """
    # Preserve the order of ``self.parameters``; only membership in ``names`` counts.
    return tuple(filter(lambda p: p.name in names, self.parameters))


# Register deserialization hook
converter.register_structure_hook(SubspaceContinuous, select_constructor_hook)
18 changes: 18 additions & 0 deletions baybe/searchspace/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,24 @@ def transform(

return comp_rep

@property
def constraints_augmentable(self) -> tuple[Constraint, ...]:
    """The searchspace constraints that can be considered during augmentation."""
    # Keep only constraints that opted in via their ``eval_during_augmentation`` flag.
    return tuple(filter(lambda c: c.eval_during_augmentation, self.constraints))

def get_parameters_by_name(self, names: Sequence[str]) -> tuple[Parameter, ...]:
    """Return parameters with the specified names.

    Args:
        names: Sequence of names.

    Returns:
        The named parameters.
    """
    # Delegate to both subspaces and join the results, discrete parameters first.
    discrete_matches = self.discrete.get_parameters_by_name(names)
    continuous_matches = self.continuous.get_parameters_by_name(names)
    return discrete_matches + continuous_matches


def validate_searchspace_from_config(specs: dict, _) -> None:
"""Validate the search space specifications while skipping costly creation steps."""
Expand Down
13 changes: 13 additions & 0 deletions baybe/searchspace/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,19 @@ def transform(
except AttributeError:
return comp_rep

def get_parameters_by_name(
    self, names: Sequence[str]
) -> tuple[DiscreteParameter, ...]:
    """Return parameters with the specified names.

    Args:
        names: Sequence of names.

    Returns:
        The named parameters.
    """
    # Collect matches in the order they appear in ``self.parameters``.
    matches = [p for p in self.parameters if p.name in names]
    return tuple(matches)


def _apply_constraint_filter(
df: pd.DataFrame, constraints: Collection[DiscreteConstraint]
Expand Down
220 changes: 220 additions & 0 deletions baybe/utils/augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""Utilities related to data augmentation."""

from collections.abc import Sequence
from itertools import permutations, product

import pandas as pd


def _row_in_df(row: pd.Series | pd.DataFrame, df: pd.DataFrame) -> bool:
"""Check whether a row is fully contained in a dataframe.

Args:
row: The row to be checked.
df: The dataframe to be checked.

Returns:
Boolean result.

Raises:
ValueError: If ``row`` is a dataframe that contains more than one row.
"""
if isinstance(row, pd.DataFrame):
if len(row) != 1:
raise ValueError(
f"{_row_in_df.__name__} can only be called with pd.Series or "
f"pd.DataFrames that have exactly one row."
)
row = row.iloc[0]

row = row.reindex(df.columns)
return (df == row).all(axis=1).any()


def df_apply_permutation_augmentation(
    df: pd.DataFrame,
    columns: Sequence[Sequence[str]],
) -> pd.DataFrame:
    """Augment a dataframe if permutation invariant columns are present.

    * Original

      +----+----+----+----+
      | A1 | A2 | B1 | B2 |
      +====+====+====+====+
      | a  | b  | x  | y  |
      +----+----+----+----+
      | b  | a  | x  | z  |
      +----+----+----+----+

    * Result with ``columns = [["A1"], ["A2"]]``

      +----+----+----+----+
      | A1 | A2 | B1 | B2 |
      +====+====+====+====+
      | a  | b  | x  | y  |
      +----+----+----+----+
      | b  | a  | x  | z  |
      +----+----+----+----+
      | b  | a  | x  | y  |
      +----+----+----+----+
      | a  | b  | x  | z  |
      +----+----+----+----+

    * Result with ``columns = [["A1", "B1"], ["A2", "B2"]]``

      +----+----+----+----+
      | A1 | A2 | B1 | B2 |
      +====+====+====+====+
      | a  | b  | x  | y  |
      +----+----+----+----+
      | b  | a  | x  | z  |
      +----+----+----+----+
      | b  | a  | y  | x  |
      +----+----+----+----+
      | a  | b  | z  | x  |
      +----+----+----+----+

    Args:
        df: The dataframe that should be augmented.
        columns: Sequences of permutation invariant columns. The n'th column in each
            sequence will be permuted together with each n'th column in the other
            sequences.

    Returns:
        The augmented dataframe containing the original one.

    Raises:
        ValueError: If ``columns`` contains fewer than two sequences.
        ValueError: If the sequences in ``columns`` are empty or of differing lengths.
    """
    # Validation
    if len(columns) < 2:
        raise ValueError(
            "When augmenting permutation invariance, at least two column sequences "
            "must be given."
        )
    if len({len(seq) for seq in columns}) != 1 or len(columns[0]) < 1:
        raise ValueError(
            "Permutation augmentation can only work if the amount of columns in each "
            "sequence is the same and the sequences are not empty."
        )

    # The identity permutation would merely reproduce each row, which is already
    # contained in the dataframe, so it is skipped upfront.
    identity = tuple(range(len(columns)))
    idx_permutations = [p for p in permutations(identity) if p != identity]

    # Augmentation Loop
    new_rows: list[pd.DataFrame] = []
    for _, row in df.iterrows():
        to_add = []
        for perm in idx_permutations:
            new_row = row.copy()

            # Permute the n'th columns of all sequences jointly
            for deps in map(list, zip(*columns)):
                new_row[deps] = row[[deps[k] for k in perm]]

            # Only add the new row if it is not already an existing permutation
            if not _row_in_df(new_row, df):
                to_add.append(new_row)

        new_rows.append(pd.DataFrame(to_add))
    augmented_df = pd.concat([df] + new_rows)

    return augmented_df


def df_apply_dependency_augmentation(
    df: pd.DataFrame,
    causing: tuple[str, Sequence],
    affected: Sequence[tuple[str, Sequence]],
) -> pd.DataFrame:
    """Augment a dataframe if dependency invariant columns are present.

    This works with the concept of column-values pairs for causing and affected column.
    Any row present where the specified causing column has one of the provided values
    will trigger an augmentation on the affected columns. The latter are augmented by
    going through all their invariant values and adding respective new rows.

    * Original

      +---+---+---+---+
      | A | B | C | D |
      +===+===+===+===+
      | 0 | 2 | 5 | y |
      +---+---+---+---+
      | 1 | 3 | 5 | z |
      +---+---+---+---+

    * Result with ``causing = ("A", [0])``, ``affected = [("B", [2,3,4])]``

      +---+---+---+---+
      | A | B | C | D |
      +===+===+===+===+
      | 0 | 2 | 5 | y |
      +---+---+---+---+
      | 1 | 3 | 5 | z |
      +---+---+---+---+
      | 0 | 3 | 5 | y |
      +---+---+---+---+
      | 0 | 4 | 5 | y |
      +---+---+---+---+

    * Result with ``causing = ("A", [0, 1])``, ``affected = [("B", [2,3])]``

      +---+---+---+---+
      | A | B | C | D |
      +===+===+===+===+
      | 0 | 2 | 5 | y |
      +---+---+---+---+
      | 1 | 3 | 5 | z |
      +---+---+---+---+
      | 0 | 3 | 5 | y |
      +---+---+---+---+
      | 1 | 2 | 5 | z |
      +---+---+---+---+

    * Result with ``causing = ("A", [0])``,
      ``affected = [("B", [2,3]), ("C", [5, 6])]``

      +---+---+---+---+
      | A | B | C | D |
      +===+===+===+===+
      | 0 | 2 | 5 | y |
      +---+---+---+---+
      | 1 | 3 | 5 | z |
      +---+---+---+---+
      | 0 | 3 | 5 | y |
      +---+---+---+---+
      | 0 | 2 | 6 | y |
      +---+---+---+---+
      | 0 | 3 | 6 | y |
      +---+---+---+---+

    Args:
        df: The dataframe that should be augmented.
        causing: Causing column name and its causing values.
        affected: Affected columns and their invariant values.

    Returns:
        The augmented dataframe containing the original one.
    """
    causing_col, causing_vals = causing
    # Only rows carrying a causing value trigger the augmentation.
    triggering = df.loc[df[causing_col].isin(causing_vals), :]
    affected_names, invariant_values = zip(*affected)
    value_combinations = list(product(*invariant_values))

    augmented: list[pd.DataFrame] = []
    for _, base_row in triggering.iterrows():
        # Create one candidate row per combination of invariant values.
        candidates = [
            pd.Series({**base_row.to_dict(), **dict(zip(affected_names, combo))})
            for combo in value_combinations
        ]

        # Discard candidates that already occur in the original data.
        fresh = [cand for cand in candidates if not _row_in_df(cand, triggering)]
        augmented.append(pd.DataFrame(fresh))

    return pd.concat([df] + augmented)
14 changes: 8 additions & 6 deletions baybe/utils/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import warnings
from collections.abc import Iterable, Iterator, Sequence
from typing import (
TYPE_CHECKING,
Expand All @@ -13,6 +14,7 @@
import numpy as np
import pandas as pd

from baybe.exceptions import NoSearchspaceMatchWarning, TooManySearchspaceMatchesWarning
from baybe.targets.enum import TargetMode
from baybe.utils.numerical import DTypeFloatNumpy

Expand Down Expand Up @@ -417,17 +419,17 @@ def fuzzy_row_match(
# We expect exactly one match. If that's not the case, print a warning.
inds_found = left_df.index[match].to_list()
if len(inds_found) == 0 and len(num_cols) > 0:
_logger.warning(
"Input row with index %s could not be matched to the search space. "
warnings.warn(
f"Input row with index {ind} could not be matched to the search space. "
"This could indicate that something went wrong.",
ind,
NoSearchspaceMatchWarning,
)
elif len(inds_found) > 1:
_logger.warning(
"Input row with index %s has multiple matches with "
warnings.warn(
f"Input row with index {ind} has multiple matches with "
"the search space. This could indicate that something went wrong. "
"Matching only first occurrence.",
ind,
TooManySearchspaceMatchesWarning,
)
inds_matched.append(inds_found[0])
else:
Expand Down
Loading
Loading