Skip to content

Commit

Permalink
Merge pull request #12 from AutoResearch/feat-make-compatible-with-state
Browse files Browse the repository at this point in the history
feat: make compatible with state and work with pd.DataFrames
  • Loading branch information
younesStrittmatter authored Sep 3, 2023
2 parents 1c14872 + aebc1ad commit daff328
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions src/autora/experimentalist/nearest_value/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
from typing import Iterable, Sequence, Union
from typing import Iterable, Union

import numpy as np
import pandas as pd

from autora.utils.deprecation import deprecated_alias


def sample(
samples: Union[Iterable, Sequence],
conditions: Union[pd.DataFrame, np.ndarray],
allowed_values: np.ndarray,
num_samples: int,
):
"""
A sampler which returns the nearest values between the input samples and the allowed values,
without replacement.
A experimentalist which returns the nearest values between the input samples and the allowed
values, without replacement.
Args:
samples: input conditions
conditions: input conditions
allowed_values: allowed conditions to sample from
num_samples: number of samples
Expand All @@ -30,27 +31,20 @@ def sample(
if len(allowed_values.shape) == 1:
allowed_values = allowed_values.reshape(-1, 1)

if isinstance(samples, Iterable):
samples = np.array(list(samples))

if allowed_values.shape[0] < num_samples:
raise Exception(
"More samples requested than samples available in the set allowed of values."
)

if isinstance(samples, Iterable) or isinstance(samples, Sequence):
samples = np.array(list(samples))
X = np.array(conditions)

if hasattr(samples, "shape"):
if samples.shape[0] < num_samples:
raise Exception(
"More samples requested than samples available in the pool."
)
if X.shape[0] < num_samples:
raise Exception("More samples requested than samples available in the pool.")

x_new = np.empty((num_samples, allowed_values.shape[1]))

# get index of row in x that is closest to each sample
for row, sample in enumerate(samples):
for row, sample in enumerate(X):

if row >= num_samples:
break
Expand All @@ -60,6 +54,9 @@ def sample(
x_new[row, :] = allowed_values[idx, :]
allowed_values = np.delete(allowed_values, idx, axis=0)

if isinstance(conditions, pd.DataFrame):
x_new = pd.DataFrame(x_new, columns=conditions.columns)

return x_new


Expand Down

0 comments on commit daff328

Please sign in to comment.