Skip to content

Commit

Permalink
add a parameter to the sampling module to allow user defining the thr…
Browse files Browse the repository at this point in the history
…eshold for slice sampling
  • Loading branch information
chyalexcheng committed Feb 8, 2024
1 parent c2ddbfe commit 76b24fc
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion grainlearning/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class GaussianMixtureModel:
This can speed up convergence when fit is called several times on similar problems. See the Glossary.
:param expand_factor: factor used when converting the ensemble from weighted to unweighted, defaults to 10, optional
:param slice_sampling: flag to use slice sampling, defaults to False, optional
:param deviation_factor: factor used to determine the threshold,
scores.mean() - deviation_factor * scores.std(), for slice sampling, defaults to 0.0, optional
:param gmm: The class of the Gaussian Mixture Model
:param max_params: Current maximum values of the parameters
"""
Expand All @@ -102,6 +104,7 @@ def __init__(
warm_start: bool = False,
expand_factor: int = 10,
slice_sampling: bool = False,
deviation_factor: float = 0.0,
):
""" Initialize the Gaussian Mixture Model class"""
self.max_num_components = max_num_components
Expand Down Expand Up @@ -129,6 +132,8 @@ def __init__(

self.slice_sampling = slice_sampling

self.deviation_factor = deviation_factor

self.max_params = None

self.expanded_normalized_params = None
Expand All @@ -154,6 +159,7 @@ def from_dict(cls: Type["GaussianMixtureModel"], obj: dict):
warm_start=obj.get("warm_start", False),
expand_factor=obj.get("expand_factor", 10),
slice_sampling=obj.get("slice_sampling", False),
deviation_factor=obj.get("deviation_factor", 0.0),
)

def expand_and_normalize_weighted_samples(self, weights: np.ndarray, system: Type["DynamicSystem"]):
Expand Down Expand Up @@ -245,7 +251,7 @@ def draw_samples_within_bounds(self, system: Type["DynamicSystem"], num: int = 1

scores = self.gmm.score_samples(self.expanded_normalized_params)
new_params = new_params[np.where(
self.gmm.score_samples(new_params) > scores.mean())]
self.gmm.score_samples(new_params) > scores.mean() - self.deviation_factor * scores.std())]

new_params *= self.max_params

Expand Down

0 comments on commit 76b24fc

Please sign in to comment.