diff --git a/src/autora/experimentalist/nearest_value/__init__.py b/src/autora/experimentalist/nearest_value/__init__.py index 4552b44..6005508 100644 --- a/src/autora/experimentalist/nearest_value/__init__.py +++ b/src/autora/experimentalist/nearest_value/__init__.py @@ -1,21 +1,22 @@ -from typing import Iterable, Sequence, Union +from typing import Iterable, Union import numpy as np +import pandas as pd from autora.utils.deprecation import deprecated_alias def sample( - samples: Union[Iterable, Sequence], + conditions: Union[pd.DataFrame, np.ndarray], allowed_values: np.ndarray, num_samples: int, ): """ - A sampler which returns the nearest values between the input samples and the allowed values, - without replacement. + A experimentalist which returns the nearest values between the input samples and the allowed + values, without replacement. Args: - samples: input conditions + conditions: input conditions allowed_values: allowed conditions to sample from num_samples: number of samples @@ -30,27 +31,20 @@ def sample( if len(allowed_values.shape) == 1: allowed_values = allowed_values.reshape(-1, 1) - if isinstance(samples, Iterable): - samples = np.array(list(samples)) - if allowed_values.shape[0] < num_samples: raise Exception( "More samples requested than samples available in the set allowed of values." ) - if isinstance(samples, Iterable) or isinstance(samples, Sequence): - samples = np.array(list(samples)) + X = np.array(conditions) - if hasattr(samples, "shape"): - if samples.shape[0] < num_samples: - raise Exception( - "More samples requested than samples available in the pool." - ) + if X.shape[0] < num_samples: + raise Exception("More samples requested than samples available in the pool.") x_new = np.empty((num_samples, allowed_values.shape[1])) # get index of row in x that is closest to each sample - for row, sample in enumerate(samples): + for row, sample in enumerate(X): if row >= num_samples: break @@ -60,6 +54,9 @@ def sample( x_new[row, :] = allowed_values[idx, :] allowed_values = np.delete(allowed_values, idx, axis=0) + if isinstance(conditions, pd.DataFrame): + x_new = pd.DataFrame(x_new, columns=conditions.columns) + return x_new