Skip to content

Commit

Permalink
Stabilise KL test
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaspinder committed Aug 19, 2024
1 parent fb989de commit b98a881
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 14 deletions.
47 changes: 42 additions & 5 deletions gpjax/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,42 @@

from gpjax.typing import (
Array,
MultivariateParams,
ScalarFloat,
UnivariateParams,
)

_JITTER = 1e-6


@st.composite
def sample_univariate_gaussian_params(draw) -> tp.Tuple[ScalarFloat, ScalarFloat]:
def stable_float(draw, min_value=-10, max_value=10, abs_threshold=1e-6) -> float:
"""
Strategy to generate floats between min_value and max_value,
excluding values with absolute value less than abs_threshold.
Parameters:
- min_value: minimum value for the float (default: -10)
- max_value: maximum value for the float (default: 10)
- abs_threshold: absolute values below this will be excluded (default: 1e-6)
"""
value = draw(
st.floats(
min_value=min_value,
max_value=max_value,
allow_infinity=False,
allow_nan=False,
)
)
# If the value is too close to zero, move it outside the excluded range
if abs(value) < abs_threshold:
sign = 1 if value >= 0 else -1
value = sign * (abs_threshold + (abs(value) % abs_threshold))
return value


@st.composite
def sample_univariate_gaussian_params(draw) -> UnivariateParams:
loc = jnp.array(draw(st.floats()), dtype=jnp.float64)
scale = jnp.array(
draw(st.floats(min_value=_JITTER, exclude_min=True)), dtype=jnp.float64
Expand All @@ -22,11 +50,13 @@ def sample_univariate_gaussian_params(draw) -> tp.Tuple[ScalarFloat, ScalarFloat


@st.composite
def sample_multivariate_gaussian_params(
draw, dim: int
) -> tp.Tuple[Float[Array, " N"], Float[Array, "N N"]]:
def sample_multivariate_gaussian_params(draw, dim: int) -> MultivariateParams:
mean = draw(
arrays(dtype=float, shape=dim, elements=st.floats(min_value=-10, max_value=10))
arrays(
dtype=float,
shape=dim,
elements=stable_float(min_value=-10, max_value=10, abs_threshold=1e-3),
)
)
lower_vals = draw(
arrays(
Expand All @@ -46,3 +76,10 @@ def is_psd(matrix: Float[Array, "N N"]) -> bool:
psd_status = jnp.all(eig_vals > 0.0)
# except jnp.linalg.Lin
return psd_status


def approx_equal(
res: Float[Array, "N N"], actual: Float[Array, "N N"], threshold: float = 1e-6
) -> bool:
"""Check if two arrays are approximately equal."""
return jnp.linalg.norm(res - actual) < threshold
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
with install_import_hook("gpjax", "beartype.beartype"):
import gpjax # noqa: F401

settings.register_profile("ci", max_examples=100, deadline=None)
settings.register_profile("ci", max_examples=300, deadline=None)
settings.register_profile("local_dev", max_examples=5, deadline=None)
16 changes: 8 additions & 8 deletions tests/test_gaussian_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
import pytest

from gpjax.distributions import GaussianDistribution
from gpjax.testing import sample_multivariate_gaussian_params
from gpjax.testing import (
approx_equal,
sample_multivariate_gaussian_params,
)
from gpjax.typing import MultivariateParams

_key = jr.key(seed=42)
Expand All @@ -44,12 +47,7 @@
)

MIN_DIM = 1
MAX_DIM = 50


def approx_equal(res: jnp.ndarray, actual: jnp.ndarray) -> bool:
"""Check if two arrays are approximately equal."""
return jnp.linalg.norm(res - actual) < 1e-5
MAX_DIM = 20


@given(dim=st.integers(min_value=MIN_DIM, max_value=MAX_DIM), data=st.data())
Expand Down Expand Up @@ -159,7 +157,9 @@ def test_kl_divergence(dim: int, data: st.DataObject) -> None:
)

assert approx_equal(
dist_a.kl_divergence(dist_b), tfp_dist_a.kl_divergence(tfp_dist_b)
dist_a.kl_divergence(dist_b),
tfp_dist_a.kl_divergence(tfp_dist_b),
threshold=1e-3,
)

with pytest.raises(ValueError):
Expand Down

0 comments on commit b98a881

Please sign in to comment.