diff --git a/grainlearning/sampling.py b/grainlearning/sampling.py index 0c05edd..3335185 100644 --- a/grainlearning/sampling.py +++ b/grainlearning/sampling.py @@ -85,6 +85,8 @@ class GaussianMixtureModel: This can speed up convergence when fit is called several times on similar problems. See the Glossary. :param expand_factor: factor used when converting the ensemble from weighted to unweighted, defaults to 10, optional :param slice_sampling: flag to use slice sampling, defaults to False, optional + :param deviation_factor: factor used to determine the threshold, + scores.mean() - deviation_factor * scores.std(), for slice sampling, defaults to 0.0, optional :param gmm: The class of the Gaussian Mixture Model :param max_params: Current maximum values of the parameters """ @@ -102,6 +104,7 @@ def __init__( warm_start: bool = False, expand_factor: int = 10, slice_sampling: bool = False, + deviation_factor: float = 0.0, ): """ Initialize the Gaussian Mixture Model class""" self.max_num_components = max_num_components @@ -129,6 +132,8 @@ def __init__( self.slice_sampling = slice_sampling + self.deviation_factor = deviation_factor + self.max_params = None self.expanded_normalized_params = None @@ -154,6 +159,7 @@ def from_dict(cls: Type["GaussianMixtureModel"], obj: dict): warm_start=obj.get("warm_start", False), expand_factor=obj.get("expand_factor", 10), slice_sampling=obj.get("slice_sampling", False), + deviation_factor=obj.get("deviation_factor", 0.0), ) def expand_and_normalize_weighted_samples(self, weights: np.ndarray, system: Type["DynamicSystem"]): @@ -245,7 +251,7 @@ def draw_samples_within_bounds(self, system: Type["DynamicSystem"], num: int = 1 scores = self.gmm.score_samples(self.expanded_normalized_params) new_params = new_params[np.where( - self.gmm.score_samples(new_params) > scores.mean())] + self.gmm.score_samples(new_params) > scores.mean() - self.deviation_factor * scores.std())] new_params *= self.max_params