diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index b6d07cc2b..4a71d8969 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -39,7 +39,7 @@ jobs: # Install the dependencies - name: Install Package run: | - poetry install --all-extras --with docs + poetry install --with docs # Run the unit tests and build the coverage report - name: Run Integration Tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 51464c854..958818a7c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,14 +46,14 @@ repos: language: system types: [python] exclude: examples/ - - repo: https://github.com/econchick/interrogate - rev: 1.5.0 - hooks: - - id: interrogate - args: - [ - "gpjax", - "--config", - "pyproject.toml", - ] - pass_filenames: false + # - repo: https://github.com/econchick/interrogate + # rev: 1.5.0 + # hooks: + # - id: interrogate + # args: + # [ + # "gpjax", + # "--config", + # "pyproject.toml", + # ] + # pass_filenames: false diff --git a/docs/examples/barycentres.py b/docs/examples/barycentres.py index e2e21fd9f..97120f70f 100644 --- a/docs/examples/barycentres.py +++ b/docs/examples/barycentres.py @@ -1,3 +1,20 @@ +# -*- coding: utf-8 -*- +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: gpjax +# language: python +# name: python3 +# --- + # %% [markdown] # # Gaussian Processes Barycentres # diff --git a/docs/examples/constructing_new_kernels.py b/docs/examples/constructing_new_kernels.py index e6486f702..cc3c1ee0f 100644 --- a/docs/examples/constructing_new_kernels.py +++ b/docs/examples/constructing_new_kernels.py @@ -27,9 +27,6 @@ config.update("jax_enable_x64", True) -from dataclasses import dataclass - -from jax import jit import jax.numpy as jnp import jax.random as jr from jaxtyping import ( @@ -84,11 +81,11 @@ meanf = gpx.mean_functions.Zero() -for k, ax in zip(kernels, axes.ravel()): +for k, ax, c in zip(kernels, axes.ravel(), cols): prior = gpx.gps.Prior(mean_function=meanf, kernel=k) rv = prior(x) y = rv.sample(seed=key, sample_shape=(10,)) - ax.plot(x, y.T, alpha=0.7) + ax.plot(x, y.T, alpha=0.7, color=c) ax.set_title(k.name) # %% [markdown] @@ -205,24 +202,42 @@ # %% +from gpjax.kernels.computations import DenseKernelComputation +from gpjax.parameters import DEFAULT_BIJECTION, Static, PositiveReal + + def angular_distance(x, y, c): return jnp.abs((x - y + c) % (c * 2) - c) bij = tfb.SoftClip(low=jnp.array(4.0, dtype=jnp.float64)) +DEFAULT_BIJECTION["polar"] = bij + -@dataclass class Polar(gpx.kernels.AbstractKernel): - period: float = static_field(2 * jnp.pi) - tau: float = param_field(jnp.array([5.0]), bijector=bij) + period: Static + tau: PositiveReal + + def __init__( + self, + tau: float = 5.0, + period: float = 2 * jnp.pi, + active_dims: list[int] | slice | None = None, + n_dims: int | None = None, + ): + super().__init__(active_dims, n_dims, DenseKernelComputation()) + self.period = Static(jnp.array(period)) + self.tau = PositiveReal(jnp.array(tau), tag="polar") def __call__( self, x: Float[Array, "1 D"], y: Float[Array, "1 D"] ) -> Float[Array, "1"]: - c = self.period / 2.0 + c = self.period.value / 2.0 t = angular_distance(x, y, c) - K = (1 + self.tau * t / c) * jnp.clip(1 - t / c, 0, jnp.inf) ** self.tau + K = (1 + self.tau.value * t / c) * jnp.clip( + 1 - t / c, 0, jnp.inf + ) 
** self.tau.value return K.squeeze() @@ -265,7 +280,7 @@ def __call__( # Optimise GP's marginal log-likelihood using BFGS opt_posterior, history = gpx.fit_scipy( model=circular_posterior, - objective=jit(gpx.objectives.ConjugateMLL(negative=True)), + objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d), train_data=D, ) diff --git a/docs/examples/deep_kernels.py b/docs/examples/deep_kernels.py index 638d3329e..ddf0b0420 100644 --- a/docs/examples/deep_kernels.py +++ b/docs/examples/deep_kernels.py @@ -1,3 +1,20 @@ +# -*- coding: utf-8 -*- +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: gpjax +# language: python +# name: python3 +# --- + # %% [markdown] # # Deep Kernel Learning # @@ -18,10 +35,12 @@ dataclass, field, ) -from typing import Any -import flax -from flax import linen as nn +from flax.experimental import nnx +from gpjax.kernels.computations import ( + AbstractKernelComputation, + DenseKernelComputation, +) import jax import jax.numpy as jnp import jax.random as jr @@ -95,25 +114,17 @@ # %% @dataclass class DeepKernelFunction(AbstractKernel): - base_kernel: AbstractKernel = None - network: nn.Module = static_field(None) - dummy_x: jax.Array = static_field(None) - key: jr.PRNGKeyArray = static_field(jr.PRNGKey(123)) - nn_params: Any = field(init=False, repr=False) - - def __post_init__(self): - if self.base_kernel is None: - raise ValueError("base_kernel must be specified") - if self.network is None: - raise ValueError("network must be specified") - self.nn_params = flax.core.unfreeze(self.network.init(key, self.dummy_x)) + base_kernel: AbstractKernel + network: nnx.Module + compute_engine: AbstractKernelComputation = field( + default_factory=lambda: DenseKernelComputation() + ) def __call__( self, x: Float[Array, " D"], y: Float[Array, " D"] ) -> Float[Array, "1"]: - state = self.network.init(self.key, x) - xt = self.network.apply(state, x) - yt = self.network.apply(state, y) + xt = self.network(x) + yt = self.network(y) return self.base_kernel(xt, yt) @@ -135,20 +146,25 @@ def __call__( feature_space_dim = 3 -class Network(nn.Module): - """A simple MLP.""" +class Network(nnx.Module): + def __init__( + self, rngs: nnx.Rngs, *, input_dim: int, inner_dim: int, feature_space_dim: int + ) -> None: + self.layer1 = nnx.Linear(input_dim, inner_dim, rngs=rngs) + self.output_layer = nnx.Linear(inner_dim, feature_space_dim, rngs=rngs) + self.rngs = rngs - @nn.compact - def __call__(self, x): - x = nn.Dense(features=32)(x) - x = nn.relu(x) - x = nn.Dense(features=64)(x) - x = nn.relu(x) - x = nn.Dense(features=feature_space_dim)(x) + def __call__(self, x: jax.Array) -> jax.Array: + x = x.reshape((x.shape[0], -1)) + x = self.layer1(x) + x = jax.nn.relu(x) + x = self.output_layer(x).squeeze() return x -forward_linear = Network() +forward_linear = Network( + nnx.Rngs(123), feature_space_dim=feature_space_dim, inner_dim=32, input_dim=1 +) # %% [markdown] # ## Defining a model @@ -162,9 +178,7 @@ def __call__(self, x): active_dims=list(range(feature_space_dim)), lengthscale=jnp.ones((feature_space_dim,)), ) -kernel = DeepKernelFunction( - network=forward_linear, base_kernel=base_kernel, key=key, dummy_x=x -) +kernel = DeepKernelFunction(network=forward_linear, base_kernel=base_kernel) meanf = gpx.mean_functions.Zero() prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) likelihood = 
gpx.likelihoods.Gaussian(num_datapoints=D.n) @@ -202,7 +216,7 @@ def __call__(self, x): opt_posterior, history = gpx.fit( model=posterior, - objective=jax.jit(gpx.objectives.ConjugateMLL(negative=True)), + objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d), train_data=D, optim=optimiser, num_iters=800, diff --git a/docs/examples/graph_kernels.py b/docs/examples/graph_kernels.py index 1c58dc94e..a328f808c 100644 --- a/docs/examples/graph_kernels.py +++ b/docs/examples/graph_kernels.py @@ -1,3 +1,20 @@ +# -*- coding: utf-8 -*- +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: gpjax +# language: python +# name: python3 +# --- + # %% [markdown] # # Graph Kernels # @@ -154,7 +171,7 @@ # %% opt_posterior, training_history = gpx.fit_scipy( model=posterior, - objective=gpx.objectives.ConjugateMLL(negative=True), + objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d), train_data=D, ) diff --git a/docs/examples/intro_to_kernels.py b/docs/examples/intro_to_kernels.py index 0d09c2852..e98d4c7b0 100644 --- a/docs/examples/intro_to_kernels.py +++ b/docs/examples/intro_to_kernels.py @@ -1,3 +1,20 @@ +# -*- coding: utf-8 -*- +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: gpjax +# language: python +# name: python3 +# --- + # %% [markdown] # # Introduction to Kernels @@ -213,6 +230,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: # noqa: F821 # First we define our model, using the Matérn52 kernel, and construct our posterior *without* optimising the kernel hyperparameters: # %% +from gpjax.parameters import PositiveReal + mean = gpx.mean_functions.Zero() kernel = gpx.kernels.Matern52( lengthscale=jnp.array(0.1) @@ -221,9 +240,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: # noqa: F821 prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) likelihood = gpx.likelihoods.Gaussian( - num_datapoints=D.n, obs_stddev=jnp.array(1e-3) + num_datapoints=D.n, obs_stddev=PositiveReal(value=jnp.array(1e-3), tag="Static") ) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value -likelihood = likelihood.replace_trainable(obs_stddev=False) no_opt_posterior = prior * likelihood @@ -231,14 +249,13 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: # noqa: F821 # We can then optimise the hyperparameters by minimising the negative log marginal likelihood of the data: # %% -negative_mll = gpx.objectives.ConjugateMLL(negative=True) -negative_mll(no_opt_posterior, train_data=D) +gpx.objectives.conjugate_mll(no_opt_posterior, data=D) # %% opt_posterior, history = gpx.fit_scipy( model=no_opt_posterior, - objective=negative_mll, + objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d), train_data=D, ) @@ -499,17 +516,20 @@ def plot_ribbon(ax, x, dist, color): posterior = prior * likelihood + # %% [markdown] # With our model constructed, let's now fit it to the data, by minimising the negative log # marginal likelihood of the data: + # %% -negative_mll = gpx.objectives.ConjugateMLL(negative=True) -negative_mll(posterior, train_data=D) +def loss(posterior, data): + return -gpx.objectives.conjugate_mll(posterior, data) + opt_posterior, history = gpx.fit( 
model=posterior, - objective=negative_mll, + objective=loss, train_data=D, optim=ox.adamw(learning_rate=1e-2), num_iters=500, diff --git a/docs/examples/pytrees.md b/docs/examples/pytrees.md deleted file mode 100644 index 3f63004a7..000000000 --- a/docs/examples/pytrees.md +++ /dev/null @@ -1,632 +0,0 @@ -# 🌳 GPJax Module - -`GPJax` **represents all objects as JAX -[_PyTrees_](https://jax.readthedocs.io/en/latest/pytrees.html)**, giving - -- A simple API with a **TensorFlow / PyTorch feel** ... -- ... whilst **fully compatible** with JAX's functional paradigm ... -- ... And **works out of the box** (no filtering) with JAX's transformations - such as `grad`. - -We achieve this through providing a base `Module` abstraction to cleanly -handle parameter trainability and optimising transformations of JAX models. - - - -# Gaussian process objects as data - -Our abstraction is inspired by the [Equinox](https://github.com/patrick-kidger/equinox) library and aims to offer a -Bayesian/Gaussian process extension to their neural network abstractions. Our -approach enables users to easily create Python classes and -define parameter domains and training statuses for optimisation within a -single model object. This object can be used with JAX autogradients without -any filtering. - -The fundamental concept is to describe every model object as an immutable tree -structure, where every method is a function of the state (represented by the -tree's leaves). - -To help you understand how to create custom objects in GPJax, we will look at -a simple example in the following section. - - -## The RBF kernel - - -The kernel in a Gaussian process model is a mathematical function that -defines the covariance structure between data points, allowing us to model -complex relationships and make predictions based on the observed data. The -radial basis function (RBF, or _squared exponential_) kernel is a popular -choice. For any pair of vectors $`x, y \in \mathbb{R}^d`$, its form is -given by - -```math -k(x, y) = \sigma^2\exp\left(\frac{\lVert -x-y\rVert_{2}^2}{2\ell^2} \right) -``` - -where $`\sigma^2\in\mathbb{R}_{>0}`$ is a -variance parameter and $`\ell^2\in\mathbb{R}_{>0}`$ a lengthscale parameter. -Terming the evaluation of $`k(x, y)`$ the _covariance_, we can represent -this object as a Python `dataclass` as follows: - - -```python -import jax -import jax.numpy as jnp -from dataclasses import dataclass, field - - -@dataclass -class RBF: - lengthscale: float = field(default=1.0) - variance: float = field(default=1.0) - - def covariance(self, x: float, y: float) -> jax.Array: - return self.variance * jnp.exp(-0.5 * ((x - y) / self.lengthscale) ** 2) -``` - - -Here, the Python `dataclass` is a class that simplifies the process of -creating classes that primarily store data. It reduces boilerplate code and -provides convenient methods for initialising and representing the data. An -equivalent class could be written as: - -```python -class RBF: - def __init__(self, lengthscale: float = 1.0, variance: float = 1.0) -> None: - self.lengthscale = lengthscale - self.variance = variance - - def covariance(self, x: float, y: float) -> jax.Array: - return self.variance * jnp.exp(-0.5 * ((x-y) / self.lengthscale)**2) -``` - - -To establish some terminology, within the above RBF `dataclass`, we refer to -the lengthscale and variance as _fields_. Further, the `RBF.covariance` is a -_method_. So far so good. 
However, if we wanted to take the gradient of -the kernel with respect to its parameters $`\nabla_{\ell, \sigma^2} k(1.0, 2.0; -\ell, \sigma^2)`$ at inputs $`x=1.0`$ and $`y=2.0`$, then we encounter a problem: - -```python -kernel = RBF() - -try: - jax.grad(lambda kern: kern.covariance(1.0, 2.0))(kernel) -except TypeError as e: - print(e) -``` -```console -Argument 'RBF(lengthscale=1.0, variance=1.0)' of type is not a valid JAX type. -``` - -This issues arises as the object we have defined is not yet -compatible with JAX. To achieve this we must consider [JAX's _PyTree_](https://jax.readthedocs.io/en/latest/pytrees.html) -abstraction. - - -## PyTrees - -JAX PyTrees are a powerful tool in the JAX library that enable users to work -with complex data structures in a way that is efficient, flexible, and easy to -use. A PyTree is a data structure that is composed of other data -structures, and it can be thought of as a tree where each 'node' is either a -leaf (a simple data structure) or another PyTree. By default, the set -of 'node' types that are regarded a PyTree are Python lists, tuples, and -dicts. - -For instance: - -```python -tree = [3.14, {"Monte": object(), "Carlo": False}] -print(tree) -``` -```console -[3.14, {'Monte': , 'Carlo': False}] -``` -is a PyTree with structure - -```python -import jax.tree_util as jtu - -print(jtu.tree_structure(tree)) -``` -```console -PyTreeDef([*, {'Carlo': *, 'Monte': *}]) -``` -with the following leaves - -```python -print(jtu.tree_leaves(tree)) -``` -```console -[3.14, False, ] -``` - -Consider a second example, a _PyTree of JAX arrays_ - -```python -tree = ( - jnp.array([1.0, 2.0, 3.0]), - jnp.array([4.0, 5.0, 6.0]), - jnp.array([7.0, 8.0, 9.0]), -) -``` - - -You can use this template to perform various operations on the data, such as -applying a function to each leaf of the PyTree. - - - -For example, suppose you want to square each element of the arrays. You can -then apply this using the `tree_map` function from the `jax.tree_util` module: - - -```python -print(jtu.tree_map(lambda x: x**2, tree)) -``` -```console -(Array([1., 4., 9.], dtype=float32), Array([16., 25., 36.], dtype=float32), Array([49., 64., 81.], dtype=float32)) -``` - -In this example, the PyTree makes it easy to apply a function to each leaf of -a complex data structure, without having to manually traverse the data -structure and handle each leaf individually. JAX PyTrees, therefore, are a -powerful tool that can simplify many tasks in machine learning and scientific -computing. As such, most JAX functions operate over _PyTrees of JAX arrays_. -For instance, `jax.lax.scan`, accepts as input and produces as output a -PyTree of JAX arrays. - -Another key advantages of using JAX PyTrees is that they are designed to work -efficiently with JAX's automatic differentiation and compilation features. 
For -example, suppose you have a function that takes a PyTree as input and returns -a scalar value: - - -```python -def sum_squares(x): - return jnp.sum(x[0] ** 2 + x[1] ** 2 + x[2] ** 2) - -sum_squares(tree) -``` -```console -Array(285., dtype=float32) -``` - -You can use JAX's `grad` function to automatically compute the gradient of -this function with respect to the input PyTree: - -```python -gradient = jax.grad(sum_squares)(tree) -print(gradient) -``` -```console -(Array([2., 4., 6.], dtype=float32), Array([ 8., 10., 12.], dtype=float32), Array([14., 16., 18.], dtype=float32)) -``` - -This computes the gradient of the `sum_squares` function with respect to the -input PyTree, and returns a new PyTree with the same shape and structure. - -JAX PyTrees are also designed to be highly extensible, where custom types can be readily registered through a global registry with the -values of such traversed recursively (i.e., as a tree!). This means we can -define our own custom data structures and use them as PyTrees. This is the -functionality that we exploit, whereby we construct all Gaussian process -models via a tree-structure through our `Module` object. - - -# Module - -Our design, first and foremost, minimises additional abstractions on top of -standard JAX: everything is just PyTrees and transformations on PyTrees, and -secondly, provides full compatibility with the main JAX library itself, -enhancing integrability with the broader ecosystem of third-party JAX -libraries. To achieve this, our core idea is represent all model objects via -an immutable PyTree. Here the leaves of the PyTree represent the parameters -that are to be trained, and we describe their domain and trainable status as -`dataclass` metadata. - -For our RBF kernel we have two parameters; the lengthscale and the variance. -Both of these have positive domains, and by default we want to train both of -these parameters. To encode this we use a `param_field`, where we can define -the domain of both parameters via a `Softplus` bijector (that restricts them -to the positive domain), and set their trainable status to `True`. - -```python -import tensorflow_probability.substrates.jax.bijectors as tfb -from gpjax.base import Module, param_field - - -@dataclass -class RBF(Module): - lengthscale: float = param_field(1.0, bijector=tfb.Softplus(), trainable=True) - variance: float = param_field(1.0, bijector=tfb.Softplus(), trainable=True) - - def covariance(self, x: jax.Array, y: jax.Array) -> jax.Array: - return self.variance * jnp.exp(-0.5 * ((x - y) / self.lengthscale) ** 2) -``` - - -Here `param_field` is just a special type of `dataclasses.field`. As such the -following: - -```python -param_field(1.0, bijector= tfb.Identity(), trainable=False) -``` - -is equivalent to the following `dataclasses.field` - -```python -field(default=1.0, metadata={"trainable": False, "bijector": tfb.Identity()}) -``` - - -By default unmarked leaf attributes default to an `Identity` bijector and -trainablility set to `True`. - - - -### Replacing values -For consistency with JAX’s functional programming principles, `Module` -instances are immutable. PyTree nodes can be changed out-of-place via the -`replace` method. - -```python -kernel = RBF() -kernel = kernel.replace(lengthscale=3.14) # Update e.g., the lengthscale. -print(kernel) -``` -```console -RBF(lengthscale=3.14, variance=1.0) -``` - -## Transformations 🤖 - -Use `constrain` / `unconstrain` to return a `Module` with each parameter's -bijector `forward` / `inverse` operation applied! 
- -```python -# Transform kernel to unconstrained space -unconstrained_kernel = kernel.unconstrain() -print(unconstrained_kernel) - -# Transform kernel back to constrained space -kernel = unconstrained_kernel.constrain() -print(kernel) -``` -```console -RBF(lengthscale=Array(3.0957527, dtype=float32), variance=Array(0.54132485, dtype=float32)) -RBF(lengthscale=Array(3.14, dtype=float32), variance=Array(1., dtype=float32)) -``` - -Default transformations can be replaced on an instance via the -`replace_bijector` method. - -```python -new_kernel = kernel.replace_bijector(lengthscale=tfb.Identity()) - -# Transform kernel to unconstrained space -unconstrained_kernel = new_kernel.unconstrain() -print(unconstrained_kernel) - -# Transform kernel back to constrained space -new_kernel = unconstrained_kernel.constrain() -print(new_kernel) -``` -```console -RBF(lengthscale=Array(3.14, dtype=float32), variance=Array(0.54132485, dtype=float32)) -RBF(lengthscale=Array(3.14, dtype=float32), variance=Array(1., dtype=float32)) -``` - -## Trainability 🚂 - -Recall the example earlier, where we wanted to take the gradient of the kernel -with respect to its parameters $`\nabla_{\ell, \sigma^2} k(1.0, 2.0; \ell,\sigma^2)`$ at inputs $`x=1.0`$ and $`y=2.0`$. We can now confirm we can do this -with the new `Module`. - -```python -kernel = RBF() - -jax.grad(lambda kern: kern.covariance(1.0, 2.0))(kernel) -``` -```console -RBF(lengthscale=Array(0.60653067, dtype=float32, weak_type=True), variance=Array(0.60653067, dtype=float32, weak_type=True)) -``` - -During gradient learning of models, it can sometimes be useful to fix certain -parameters during the optimisation routine. For this, JAX provides a -`stop_gradient` operand to prevent the flow of gradients during forward or -reverse-mode automatic differentiation, as illustrated below for a function -$`f(x) = x^2`$. - -```python -from jax import lax - - -def f(x): - x = lax.stop_gradient(x) - return x**2 - - -jax.grad(f)(1.0) -``` -```console -Array(0., dtype=float32, weak_type=True) -``` - -We see that gradient return is `0.0` instead of `2.0` due to the stopping of -the gradient. Analogous to this, we provide this functionality to gradient -flows on our `Module` class, via a `stop_gradient` method. - -Setting a (leaf) parameter's trainability to false can be achieved via the -`replace_trainable` method. - -```python - -kernel = RBF() -kernel = kernel.replace_trainable(lengthscale=False) - -jax.grad(lambda kern: kern.stop_gradient().covariance(1.0, 2.0))(kernel) -``` -```console -RBF(lengthscale=Array(0., dtype=float32, weak_type=True), variance=Array(0.60653067, dtype=float32, weak_type=True)) -``` - -As expected, the gradient is zero for the lengthscale parameter. - - -## Static fields - -In machine learning, initialising model parameters from random points is a -common practice because it helps to break the symmetry in the model and allows -the optimization algorithm to explore different regions of the parameter -space. 
- -We could cleanly do this within the RBF class via a `post_init` method as -follows: - -```python -import jax.random as jr -import tensorflow_probability.substrates.jax.distributions as tfd -from dataclasses import field - -@dataclass -class RBF(Module): - lengthscale: float = param_field( - init=False, bijector=tfb.Softplus(), trainable=True - ) - variance: float = param_field(init=False, bijector=tfb.Softplus(), trainable=True) - key: jr.KeyArray = field(default_factory = lambda: jr.PRNGKey(42)) - # Note, for Python <3.11 you may use the following: - # key: jr.KeyArray = jr.PRNGKey(42) - - def __post_init__(self): - # Split key into two keys - key1, key2 = jr.split(self.key) - - # Sample from Gamma distribution to initialise lengthscale and variance - self.lengthscale = tfd.Gamma(1.0, 0.1).sample(seed=key1) - self.variance = tfd.Gamma(1.0, 0.1).sample(seed=key2) - - def covariance(self, x: jax.Array, y: jax.Array) -> jax.Array: - return self.variance * jnp.exp(-0.5 * ((x - y) / self.lengthscale) ** 2) - - -kernel = RBF() -print(kernel) -``` -```console -RBF(lengthscale=Array(0.54950446, dtype=float32), variance=Array(2.8077831, dtype=float32), key=Array([ 0, 42], dtype=uint32)) -``` - -So far so good. But however, if we now took our gradient again - -```python -try: - jax.grad(lambda kern: kern.stop_gradient().covariance(1.0, 2.0))(kernel) -except TypeError as e: - print(e) -``` -```console -grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got uint32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True. -``` - -We observe that we get a TypeError because the key is not differentiable. We -can fix this by using a `static_field` for defining our key attribute. - -```python -from gpjax.base import static_field - -@dataclass -class RBF(Module): - lengthscale: float = param_field( - init=False, bijector=tfb.Softplus(), trainable=True - ) - variance: float = param_field(init=False, bijector=tfb.Softplus(), trainable=True) - key: jr.KeyArray = static_field(default_factory=lambda: jr.PRNGKey(42)) - - def __post_init__(self): - # Split key into two keys - key1, key2 = jr.split(self.key) - - # Sample from Gamma distribution to initialise lengthscale and variance - self.lengthscale = tfd.Gamma(1.0, 0.1).sample(seed=key1) - self.variance = tfd.Gamma(1.0, 0.1).sample(seed=key2) - - def covariance(self, x: jax.Array, y: jax.Array) -> jax.Array: - return self.variance * jnp.exp(-0.5 * ((x - y) / self.lengthscale) ** 2) - - -fixed_kernel = RBF() -print(fixed_kernel) -``` -```console -RBF(lengthscale=Array(0.54950446, dtype=float32), variance=Array(2.8077831, dtype=float32), key=Array([ 0, 42], dtype=uint32)) -``` - -So we get the same class as before. But this time - -```python -jax.grad(lambda kern: kern.stop_gradient().covariance(1.0, 2.0))(fixed_kernel) -``` -```console -RBF(lengthscale=Array(3.230818, dtype=float32), variance=Array(0.19092491, dtype=float32), key=Array([ 0, 42], dtype=uint32)) -``` - -What happened to get the result we wanted? The difference lies in the -treatment of the key attribute as a PyTree leaf in the first example, which -caused the gradient computation to fail. 
Examining the flattened PyTree's of -both cases: - -```python -print(jax.tree_util.tree_flatten(fixed_kernel)) -print(jax.tree_util.tree_flatten(kernel)) -``` -```console -([Array(0.54950446, dtype=float32), Array(2.8077831, dtype=float32)], PyTreeDef(CustomNode(RBF[(['lengthscale', 'variance'], [('key', Array([ 0, 42], dtype=uint32))])], [*, *]))) -([Array([ 0, 42], dtype=uint32), Array(0.54950446, dtype=float32), Array(2.8077831, dtype=float32)], PyTreeDef(CustomNode(RBF[(['key', 'lengthscale', 'variance'], [])], [*, *, *]))) -``` - -We see that assigning `static_field` tells JAX not to regard the attribute as -leaf of the PyTree. - - -## Metadata - - -To determine the parameter domain and trainable statuses of each parameter, -the `Module` stores metadata for each leaf of the PyTree. This metadata is -defined through a `dataclasses.field`. Thus, under the hood, we can define our -`RBF` kernel object (equivalent to before) manually as follows: - -```python -from dataclasses import field - - -@dataclass -class RBF(Module): - lengthscale: float = field( - default=1.0, metadata={"bijector": tfb.Softplus(), "trainable": True} - ) - variance: float = field( - default=1.0, metadata={"bijector": tfb.Softplus(), "trainable": True} - ) - - def covariance(self, x: jax.Array, y: jax.Array) -> jax.Array: - return self.variance * jnp.exp(-0.5 * ((x - y) / self.lengthscale) ** 2) -``` - -Here the `metadata` in the `dataclasses.field`, defines the metadata we -associate with each PyTree leaf. This metadata can be a dictionary of any -attributes we wish to store about each leaf. For example, we could extend this -further by introducing a `name` attribute: - -```python -from dataclasses import field - - -@dataclass -class RBF(Module): - lengthscale: float = field( - default=1.0, - metadata={"bijector": tfb.Softplus(), "trainable": True, "name": "lengthscale"}, - ) - variance: float = field( - default=1.0, - metadata={"bijector": tfb.Softplus(), "trainable": True, "name": "variance"}, - ) - - def covariance(self, x: jax.Array, y: jax.Array) -> jax.Array: - return self.variance * jnp.exp(-0.5 * ((x - y) / self.lengthscale) ** 2) -``` - -We can trace the metadata defined on the class via `meta_leaves`. - -```python -from gpjax.base import meta_leaves - -rbf = RBF() - -meta_leaves(rbf) -``` -```console -[({'bijector': , - 'trainable': True, - 'name': 'lengthscale'}, - 1.0), - ({'bijector': , - 'trainable': True, - 'name': 'variance'}, - 1.0)] -``` - -Similar to `jax.tree_utils.tree_leaves`, this function returns a flattened -PyTree. However, instead of just the values, it returns a list of tuples that -contain both the metadata and value of each PyTree leaf. This traced metadata -can be utilised for applying maps (how `constrain`, `unconstrain`, -`stop_gradient` work), as described in the next section. - - -## Metamap - - -The `constrain`, `unconstrain`, and `stop_gradient` methods on the `Module` -use a `meta_map` function under the hood. This function enables us to apply -metadata functions to the PyTree leaves, making it a powerful tool. - -To achieve this, the function involves the same tracing as `meta_leaves` to -create a flattened list of tuples consisting of (metadata, leaf value). -However, it also allows us to apply a function to this list and return a new -transformed PyTree, as demonstrated in the examples that follow. - - -### Filter example: - -A `meta_map` works similarly to `jax.tree_utils.tree_map`. 
However, it differs in that it allows us to define a function that operates on -the tuple (metadata, leaf value). For example, we could use a function to -filter based on a `name` attribute. - - -```python -from gpjax.base import meta_map - - -def filter_lengthscale(meta_leaf): - meta, leaf = meta_leaf - if meta.get("name", None) == "lengthscale": - return 3.14 - else: - return leaf - - -print(meta_map(filter_lengthscale, rbf)) -``` -```console -RBF(lengthscale=3.14, variance=1.0) -``` - -### How `constrain` works: - - -To apply a constrain, we filter on the attribute "bijector", and apply a -forward transformation to the PyTree leaf: - -```python -# This is how constrain works! ⛏ -def _apply_constrain(meta_leaf): - meta, leaf = meta_leaf - - if meta is None: - return leaf - - return meta.get("bijector", tfb.Identity()).forward(leaf) - - -meta_map(_apply_constrain, rbf) -``` -```console -RBF(lengthscale=Array(1.3132617, dtype=float32), variance=Array(1.3132617, dtype=float32)) -``` - -As expected, we find the same result as calling `rbf.constrain()`. diff --git a/docs/examples/regression.py b/docs/examples/regression.py index a93c8e1d0..129aa42bf 100644 --- a/docs/examples/regression.py +++ b/docs/examples/regression.py @@ -17,7 +17,7 @@ # %% [markdown] # # Regression # -# In this notebook we demonstate how to fit a Gaussian process regression model. +# In this notebook we demonstrate how to fit a Gaussian process regression model. # %% # Enable Float64 for more stable matrix inversions. diff --git a/docs/examples/yacht.py b/docs/examples/yacht.py index c1d0958aa..29fe620b2 100644 --- a/docs/examples/yacht.py +++ b/docs/examples/yacht.py @@ -29,7 +29,6 @@ config.update("jax_enable_x64", True) -from jax import jit import jax.random as jr import jax.numpy as jnp from jaxtyping import install_import_hook diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 12d6b3369..a8c1aa5b9 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -40,7 +40,7 @@ __description__ = "Didactic Gaussian processes in JAX" __url__ = "https://github.com/JaxGaussianProcesses/GPJax" __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors" -__version__ = "0.8.0" +__version__ = "0.9.0" __all__ = [ "base", diff --git a/gpjax/fit.py b/gpjax/fit.py index fa6d414c3..2cee3aef8 100644 --- a/gpjax/fit.py +++ b/gpjax/fit.py @@ -60,37 +60,38 @@ def fit( # noqa: PLR0913 Optimisers used here should originate from Optax. Example: + ```pycon >>> import jax.numpy as jnp >>> import jax.random as jr >>> import optax as ox >>> import gpjax as gpx - >>> from gpjax.parameters import Parameter, Static - ... + >>> from gpjax.parameters import PositiveReal, Static + >>> >>> # (1) Create a dataset: >>> X = jnp.linspace(0.0, 10.0, 100)[:, None] >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape) >>> D = gpx.Dataset(X, y) >>> # (2) Define your model: >>> class LinearModel(nnx.Module): - ... def __init__(self, weight: float, bias: float): - ... self.weight = Parameter(weight) - ... self.bias = Static(bias) - ... - ... def __call__(self, x): - ... return self.weight.value * x + self.bias.value - ... + >>> def __init__(self, weight: float, bias: float): + >>> self.weight = PositiveReal(weight) + >>> self.bias = Static(bias) + >>> + >>> def __call__(self, x): + >>> return self.weight.value * x + self.bias.value + >>> >>> model = LinearModel(weight=1.0, bias=1.0) - ... + >>> >>> # (3) Define your loss function: >>> def mse(model, data): - ... pred = model(data.X) - ... 
return jnp.mean((pred - data.y) ** 2) - ... + >>> pred = model(data.X) + >>> return jnp.mean((pred - data.y) ** 2) + >>> >>> # (4) Train! >>> trained_model, history = gpx.fit( - ... model=model, objective=mse, train_data=D, optim=ox.sgd(0.001), num_iters=1000 - ... ) - + >>> model=model, objective=mse, train_data=D, optim=ox.sgd(0.001), num_iters=1000 + >>> ) + ``` Args: model (Model): The model Module to be optimised. diff --git a/gpjax/gps.py b/gpjax/gps.py index fea28cdbc..319d6b3c0 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -175,12 +175,12 @@ def __mul__(self, other): # noqa: F811 Example: ```pycon - >>> import gpjax as gpx - >>> meanf = gpx.mean_functions.Zero() - >>> kernel = gpx.kernels.RBF() - >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) - >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) - >>> prior * likelihood + >>> import gpjax as gpx + >>> meanf = gpx.mean_functions.Zero() + >>> kernel = gpx.kernels.RBF() + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) + >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) + >>> prior * likelihood ``` Args: other (Likelihood): The likelihood distribution of the observed dataset. @@ -237,12 +237,12 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution: Example: ```pycon - >>> import gpjax as gpx - >>> import jax.numpy as jnp - >>> kernel = gpx.kernels.RBF() - >>> mean_function = gpx.mean_functions.Zero() - >>> prior = gpx.gps.Prior(mean_function=mean_function, kernel=kernel) - >>> prior.predict(jnp.linspace(0, 1, 100)[:, None]) + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> kernel = gpx.kernels.RBF() + >>> mean_function = gpx.mean_functions.Zero() + >>> prior = gpx.gps.Prior(mean_function=mean_function, kernel=kernel) + >>> prior.predict(jnp.linspace(0, 1, 100)[:, None]) ``` Args: @@ -293,17 +293,17 @@ def sample_approx( Example: ```pycon - >>> import gpjax as gpx - >>> import jax.numpy as jnp - >>> import jax.random as jr - >>> key = jr.PRNGKey(123) - >>> - >>> meanf = gpx.mean_functions.Zero() - >>> kernel = gpx.kernels.RBF() - >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) - >>> - >>> sample_fn = prior.sample_approx(10, key) - >>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1)) + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> key = jr.PRNGKey(123) + >>> + >>> meanf = gpx.mean_functions.Zero() + >>> kernel = gpx.kernels.RBF(n_dims=1) + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) + >>> + >>> sample_fn = prior.sample_approx(10, key) + >>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1)) ``` Args: @@ -433,16 +433,16 @@ class ConjugatePosterior(AbstractPosterior[P, GL]): Example: ```pycon - >>> import gpjax as gpx - >>> import jax.numpy as jnp - >>> - >>> prior = gpx.gps.Prior( - mean_function = gpx.mean_functions.Zero(), - kernel = gpx.kernels.RBF() - ) - >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) - >>> - >>> posterior = prior * likelihood + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> + >>> prior = gpx.gps.Prior( + mean_function = gpx.mean_functions.Zero(), + kernel = gpx.kernels.RBF() + ) + >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) + >>> + >>> posterior = prior * likelihood ``` """ @@ -474,17 +474,17 @@ def predict( Example: ```pycon - >>> import gpjax as gpx - >>> import jax.numpy as jnp - >>> - >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1) - >>> ytrain = jnp.sin(xtrain) - >>> D = gpx.Dataset(X=xtrain, 
y=ytrain) - >>> xtest = jnp.linspace(0, 1).reshape(-1, 1) - >>> - >>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF()) - >>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints = D.n) - >>> predictive_dist = posterior(xtest, D) + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> + >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1) + >>> ytrain = jnp.sin(xtrain) + >>> D = gpx.Dataset(X=xtrain, y=ytrain) + >>> xtest = jnp.linspace(0, 1).reshape(-1, 1) + >>> + >>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF()) + >>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints = D.n) + >>> predictive_dist = posterior(xtest, D) ``` Args: diff --git a/gpjax/kernels/__init__.py b/gpjax/kernels/__init__.py index 9da0a24bf..3844ebfda 100644 --- a/gpjax/kernels/__init__.py +++ b/gpjax/kernels/__init__.py @@ -15,6 +15,7 @@ """JaxKern.""" +from gpjax.kernels import stationary from gpjax.kernels.approximations import RFF from gpjax.kernels.base import ( AbstractKernel, @@ -69,4 +70,5 @@ "White", "BasisFunctionComputation", "RFF", + "stationary", ] diff --git a/gpjax/kernels/computations/constant_diagonal.py b/gpjax/kernels/computations/constant_diagonal.py index 95aafd4c6..8d7715281 100644 --- a/gpjax/kernels/computations/constant_diagonal.py +++ b/gpjax/kernels/computations/constant_diagonal.py @@ -19,6 +19,7 @@ from cola.ops.operators import ( Diagonal, Identity, + Product, ) from jax import vmap import jax.numpy as jnp @@ -29,12 +30,13 @@ from gpjax.typing import Array K = tp.TypeVar("K", bound="gpjax.kernels.base.AbstractKernel") # noqa: F821 +ConstantDiagonalType = Product class ConstantDiagonalKernelComputation(AbstractKernelComputation): r"""Computation engine for constant diagonal kernels.""" - def _gram(self, kernel: K, x: Float[Array, "N D"]) -> Diagonal: + def gram(self, kernel: K, x: Float[Array, "N D"]) -> Product: value = kernel(x[0], x[0]) dtype = value.dtype shape = (x.shape[0], x.shape[0]) diff --git a/gpjax/kernels/computations/diagonal.py b/gpjax/kernels/computations/diagonal.py index eeb734b41..b4b6d9d47 100644 --- a/gpjax/kernels/computations/diagonal.py +++ b/gpjax/kernels/computations/diagonal.py @@ -34,7 +34,7 @@ class DiagonalKernelComputation(AbstractKernelComputation): a diagonal Gram matrix. 
""" - def _gram(self, kernel: Kernel, x: Float[Array, "N D"]) -> LinearOperator: + def gram(self, kernel: Kernel, x: Float[Array, "N D"]) -> LinearOperator: return PSD(Diagonal(diag=vmap(lambda x: kernel(x, x))(x))) def _cross_covariance( diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py index b10e1e6bf..db2b32d9b 100644 --- a/gpjax/kernels/non_euclidean/graph.py +++ b/gpjax/kernels/non_euclidean/graph.py @@ -106,7 +106,7 @@ def __call__( # TODO not consistent with general kernel interface *, S, **kwargs, - ) -> Float[Array, ""]: + ): Kxx = (jax_gather_nd(self.eigenvectors.value, x) * S.squeeze()) @ jnp.transpose( jax_gather_nd(self.eigenvectors.value, y) ) # shape (n,n) diff --git a/gpjax/kernels/stationary/base.py b/gpjax/kernels/stationary/base.py index 14563bdc9..d89f9d0d8 100644 --- a/gpjax/kernels/stationary/base.py +++ b/gpjax/kernels/stationary/base.py @@ -72,13 +72,7 @@ def __init__( """ super().__init__(active_dims, n_dims, compute_engine) - - _check_lengthscale(lengthscale) - - lengthscale, self.n_dims = _check_lengthscale_dims_compat( - lengthscale, self.n_dims - ) - + self.n_dims = _validate_lengthscale(lengthscale, self.n_dims) if isinstance(lengthscale, nnx.Variable): self.lengthscale = lengthscale else: @@ -109,17 +103,57 @@ def spectral_density(self) -> tfd.Distribution: ) +def _validate_lengthscale( + lengthscale: tp.Union[LengthscaleCompatible, nnx.Variable[Lengthscale]], + n_dims: tp.Union[int, None], +): + # Check that the lengthscale is a valid value. + _check_lengthscale(lengthscale) + + n_dims = _check_lengthscale_dims_compat(lengthscale, n_dims) + return n_dims + + def _check_lengthscale_dims_compat( lengthscale: tp.Union[LengthscaleCompatible, nnx.Variable[Lengthscale]], n_dims: tp.Union[int, None], -) -> tuple[tp.Union[Lengthscale, nnx.Variable[Lengthscale]], tp.Union[int, None]]: +): + r"""Check that the lengthscale is compatible with n_dims. + + If possible, infer the number of input dimensions from the lengthscale. + """ + + if isinstance(lengthscale, nnx.Variable): + return _check_lengthscale_dims_compat_old(lengthscale.value, n_dims) + + lengthscale = jnp.asarray(lengthscale) + ls_shape = jnp.shape(lengthscale) + + if ls_shape == (): + return n_dims + elif ls_shape != () and n_dims is None: + return ls_shape[0] + elif ls_shape != () and n_dims is not None: + if ls_shape != (n_dims,): + raise ValueError( + "Expected `lengthscale` to be compatible with the number " + f"of input dimensions. Got `lengthscale` with shape {ls_shape}, " + f"but the number of input dimensions is {n_dims}." + ) + return n_dims + + +def _check_lengthscale_dims_compat_old( + lengthscale: tp.Union[LengthscaleCompatible, nnx.Variable[Lengthscale]], + n_dims: tp.Union[int, None], +): r"""Check that the lengthscale is compatible with n_dims. If possible, infer the number of input dimensions from the lengthscale. 
""" if isinstance(lengthscale, nnx.Variable): - return _check_lengthscale_dims_compat(lengthscale.value, n_dims) + return _check_lengthscale_dims_compat_old(lengthscale.value, n_dims) lengthscale = jnp.asarray(lengthscale) ls_shape = jnp.shape(lengthscale) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 5ac8a3a31..7c1d584f5 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -7,33 +7,36 @@ import tensorflow_probability.substrates.jax.bijectors as tfb T = tp.TypeVar("T", bound=tp.Union[ArrayLike, list[float]]) +ParameterTag = str def transform( params: nnx.State, - params_bijection: tp.Dict[tp.Type, tfb.Bijector], + params_bijection: tp.Dict[str, tfb.Bijector], inverse: bool = False, ) -> nnx.State: r"""Transforms parameters using a bijector. Example: + ```pycon >>> from gpjax.parameters import PositiveReal, transform >>> import jax.numpy as jnp >>> import tensorflow_probability.substrates.jax.bijectors as tfb >>> from flax.experimental import nnx >>> params = nnx.State( - ... { - ... "a": PositiveReal(jnp.array([1.0])), - ... "b": PositiveReal(jnp.array([2.0])), - ... } - ... ) - >>> params_bijection = {PositiveReal: tfb.Softplus()} + >>> { + >>> "a": PositiveReal(jnp.array([1.0])), + >>> "b": PositiveReal(jnp.array([2.0])), + >>> } + >>> ) + >>> params_bijection = {'positive': tfb.Softplus()} >>> transformed_params = transform(params, params_bijection) >>> transformed_params["a"] - PositiveReal( - raw_value=Array([1.3132617], - dtype=float32) - ) + PositiveReal( + raw_value=Array([1.3132617], dtype=float32), + _tag='positive' + ) + ``` Args: @@ -46,7 +49,7 @@ def transform( """ def _inner(param: Parameter): - bijector = params_bijection.get(type(param), tfb.Identity()) + bijector = params_bijection.get(param._tag, tfb.Identity()) if inverse: transformed_value = bijector.inverse(param.value) @@ -71,17 +74,18 @@ class Parameter(nnx.Variable[T]): """ - def __init__(self, value: T, **kwargs): + def __init__(self, value: T, tag: ParameterTag, **kwargs): _check_is_arraylike(value) super().__init__(value=jnp.asarray(value), **kwargs) + self._tag = tag class PositiveReal(Parameter[T]): """Parameter that is strictly positive.""" - def __init__(self, value: T, **kwargs): - super().__init__(value=value, **kwargs) + def __init__(self, value: T, tag: ParameterTag = "positive", **kwargs): + super().__init__(value=value, tag=tag, **kwargs) _check_is_positive(self.value) @@ -89,12 +93,15 @@ def __init__(self, value: T, **kwargs): class Real(Parameter[T]): """Parameter that can take any real value.""" + def __init__(self, value: T, tag: ParameterTag = "real", **kwargs): + super().__init__(value, tag, **kwargs) + class SigmoidBounded(Parameter[T]): """Parameter that is bounded between 0 and 1.""" - def __init__(self, value: T, **kwargs): - super().__init__(value=value, **kwargs) + def __init__(self, value: T, tag: ParameterTag = "sigmoid", **kwargs): + super().__init__(value=value, tag=tag, **kwargs) _check_in_bounds(self.value, 0.0, 1.0) @@ -102,27 +109,28 @@ def __init__(self, value: T, **kwargs): class Static(nnx.Variable[T]): """Static parameter that is not trainable.""" - def __init__(self, value: T, **kwargs): + def __init__(self, value: T, tag: ParameterTag = "static", **kwargs): _check_is_arraylike(value) - super().__init__(value=jnp.asarray(value), **kwargs) + super().__init__(value=jnp.asarray(value), tag=tag, **kwargs) + self._tag = tag class LowerTriangular(Parameter[T]): """Parameter that is a lower triangular matrix.""" - def __init__(self, value: T, **kwargs): - 
super().__init__(value=value, **kwargs) + def __init__(self, value: T, tag: ParameterTag = "lower_triangular", **kwargs): + super().__init__(value=value, tag=tag, **kwargs) _check_is_square(self.value) _check_is_lower_triangular(self.value) DEFAULT_BIJECTION = { - PositiveReal: tfb.Softplus(), - Real: tfb.Identity(), - SigmoidBounded: tfb.Sigmoid(low=0.0, high=1.0), - LowerTriangular: tfb.FillTriangular(), + "positive": tfb.Softplus(), + "real": tfb.Identity(), + "sigmoid": tfb.Sigmoid(low=0.0, high=1.0), + "lower_triangular": tfb.FillTriangular(), } @@ -159,26 +167,3 @@ def _check_in_bounds(value: T, low: float, high: float): raise ValueError( f"Expected parameter value to be bounded between {low} and {high}. Got {value}." ) - - -# class TransformedParameter(Parameter[T]): -# """Parameter that is transformed using a bijector.""" - -# bj = tfb.Bijector - -# def create_value(self, value: T): -# return self.bj.inverse(value) - -# def get_value(self) -> T: -# return self.bj.forward(self.value) - -# def set_value(self, value: T): -# return self.replace(value=self.bj.forward(value)) - - -# class SigmoidBounded(TransformedParameter[T]): -# bj = tfb.Sigmoid(low=0.0, high=1.0) - - -# class SoftplusPositive(TransformedParameter[T]): -# bj = tfb.Softplus() diff --git a/poetry.lock b/poetry.lock index a19099846..5baa7d533 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.0 and should not be changed by hand. [[package]] name = "absl-py" @@ -215,6 +215,23 @@ charset-normalizer = ["charset-normalizer"] html5lib = ["html5lib"] lxml = ["lxml"] +[[package]] +name = "blackjax" +version = "0.9.6" +description = "Flexible and fast inference in Python" +optional = false +python-versions = "*" +files = [ + {file = "blackjax-0.9.6-py3-none-any.whl", hash = "sha256:d1c20dd15a63944a7b5c835bac4900aadf8630bedb0d7e51ab7fc63255eb0dd7"}, + {file = "blackjax-0.9.6.tar.gz", hash = "sha256:fb708f183d714750feb475fb87b8162fc1641309f30ee42fd38a5dec82733868"}, +] + +[package.dependencies] +fastprogress = ">=0.2.0" +jax = ">=0.3.13" +jaxlib = ">=0.3.10" +jaxopt = ">=0.4.2" + [[package]] name = "bleach" version = "6.1.0" @@ -570,6 +587,69 @@ traitlets = ">=4" [package.extras] test = ["pytest"] +[[package]] +name = "contourpy" +version = "1.2.1" +description = "Python library for calculating contours of 2D quadrilateral grids" +optional = false +python-versions = ">=3.9" +files = [ + {file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"}, + {file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"}, + {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480"}, + {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9"}, + {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da"}, + {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b"}, + {file = 
"contourpy-1.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd"}, + {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619"}, + {file = "contourpy-1.2.1-cp310-cp310-win32.whl", hash = "sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8"}, + {file = "contourpy-1.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9"}, + {file = "contourpy-1.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5"}, + {file = "contourpy-1.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72"}, + {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f"}, + {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965"}, + {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2"}, + {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df"}, + {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205"}, + {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8"}, + {file = "contourpy-1.2.1-cp311-cp311-win32.whl", hash = "sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec"}, + {file = "contourpy-1.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922"}, + {file = "contourpy-1.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc"}, + {file = "contourpy-1.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e"}, + {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4"}, + {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7"}, + {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0"}, + {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b"}, + {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce"}, + {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4"}, + {file = "contourpy-1.2.1-cp312-cp312-win32.whl", hash = "sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f"}, + {file = 
"contourpy-1.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce"}, + {file = "contourpy-1.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b"}, + {file = "contourpy-1.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f"}, + {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364"}, + {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe"}, + {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985"}, + {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445"}, + {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02"}, + {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083"}, + {file = "contourpy-1.2.1-cp39-cp39-win32.whl", hash = "sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba"}, + {file = "contourpy-1.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9"}, + {file = "contourpy-1.2.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609"}, + {file = "contourpy-1.2.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3"}, + {file = "contourpy-1.2.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f"}, + {file = "contourpy-1.2.1.tar.gz", hash = "sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c"}, +] + +[package.dependencies] +numpy = ">=1.20" + +[package.extras] +bokeh = ["bokeh", "selenium"] +docs = ["furo", "sphinx (>=7.2)", "sphinx-copybutton"] +mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.8.0)", "types-Pillow"] +test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] +test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] + [[package]] name = "coverage" version = "7.4.3" @@ -637,6 +717,21 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 [package.extras] toml = ["tomli"] +[[package]] +name = "cycler" +version = "0.12.1" +description = "Composable style cycles" +optional = false +python-versions = ">=3.8" +files = [ + {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"}, + {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"}, +] + +[package.extras] +docs = ["ipython", "matplotlib", "numpydoc", "sphinx"] +tests = ["pytest", "pytest-cov", "pytest-xdist"] + [[package]] name = "debugpy" version = "1.8.1" @@ -725,6 +820,13 @@ files = [ {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d"}, {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393"}, {file = "dm_tree-0.1.8-cp311-cp311-win_amd64.whl", hash = "sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75c5d528bb992981c20793b6b453e91560784215dffb8a5440ba999753c14ceb"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e"}, + {file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"}, {file = "dm_tree-0.1.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb"}, @@ -842,6 +944,17 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "fastprogress" +version = "1.0.3" +description = "A nested progress with plotting options for fastai" +optional = false +python-versions = ">=3.6" +files = [ + {file = "fastprogress-1.0.3-py3-none-any.whl", hash = "sha256:6dfea88f7a4717b0a8d6ee2048beae5dbed369f932a368c5dd9caff34796f7c5"}, + {file = "fastprogress-1.0.3.tar.gz", hash = "sha256:7a17d2b438890f838c048eefce32c4ded47197ecc8ea042cecc33d3deb8022f5"}, +] + [[package]] name = "filelock" version = "3.13.1" @@ -891,6 +1004,71 @@ url = "https://github.com/google/flax.git" reference = "HEAD" resolved_reference = "ce8a3c74d8d1f4a7d8f14b9fb84b2cc76d7f8dbf" +[[package]] +name = "fonttools" +version = "4.53.0" +description = "Tools to manipulate font files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fonttools-4.53.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:52a6e0a7a0bf611c19bc8ec8f7592bdae79c8296c70eb05917fd831354699b20"}, + {file = "fonttools-4.53.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:099634631b9dd271d4a835d2b2a9e042ccc94ecdf7e2dd9f7f34f7daf333358d"}, + {file = "fonttools-4.53.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e40013572bfb843d6794a3ce076c29ef4efd15937ab833f520117f8eccc84fd6"}, + {file = "fonttools-4.53.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:715b41c3e231f7334cbe79dfc698213dcb7211520ec7a3bc2ba20c8515e8a3b5"}, + {file = 
"fonttools-4.53.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:74ae2441731a05b44d5988d3ac2cf784d3ee0a535dbed257cbfff4be8bb49eb9"}, + {file = "fonttools-4.53.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:95db0c6581a54b47c30860d013977b8a14febc206c8b5ff562f9fe32738a8aca"}, + {file = "fonttools-4.53.0-cp310-cp310-win32.whl", hash = "sha256:9cd7a6beec6495d1dffb1033d50a3f82dfece23e9eb3c20cd3c2444d27514068"}, + {file = "fonttools-4.53.0-cp310-cp310-win_amd64.whl", hash = "sha256:daaef7390e632283051e3cf3e16aff2b68b247e99aea916f64e578c0449c9c68"}, + {file = "fonttools-4.53.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a209d2e624ba492df4f3bfad5996d1f76f03069c6133c60cd04f9a9e715595ec"}, + {file = "fonttools-4.53.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4f520d9ac5b938e6494f58a25c77564beca7d0199ecf726e1bd3d56872c59749"}, + {file = "fonttools-4.53.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eceef49f457253000e6a2d0f7bd08ff4e9fe96ec4ffce2dbcb32e34d9c1b8161"}, + {file = "fonttools-4.53.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa1f3e34373aa16045484b4d9d352d4c6b5f9f77ac77a178252ccbc851e8b2ee"}, + {file = "fonttools-4.53.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:28d072169fe8275fb1a0d35e3233f6df36a7e8474e56cb790a7258ad822b6fd6"}, + {file = "fonttools-4.53.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4a2a6ba400d386e904fd05db81f73bee0008af37799a7586deaa4aef8cd5971e"}, + {file = "fonttools-4.53.0-cp311-cp311-win32.whl", hash = "sha256:bb7273789f69b565d88e97e9e1da602b4ee7ba733caf35a6c2affd4334d4f005"}, + {file = "fonttools-4.53.0-cp311-cp311-win_amd64.whl", hash = "sha256:9fe9096a60113e1d755e9e6bda15ef7e03391ee0554d22829aa506cdf946f796"}, + {file = "fonttools-4.53.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d8f191a17369bd53a5557a5ee4bab91d5330ca3aefcdf17fab9a497b0e7cff7a"}, + {file = "fonttools-4.53.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:93156dd7f90ae0a1b0e8871032a07ef3178f553f0c70c386025a808f3a63b1f4"}, + {file = "fonttools-4.53.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bff98816cb144fb7b85e4b5ba3888a33b56ecef075b0e95b95bcd0a5fbf20f06"}, + {file = "fonttools-4.53.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:973d030180eca8255b1bce6ffc09ef38a05dcec0e8320cc9b7bcaa65346f341d"}, + {file = "fonttools-4.53.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c4ee5a24e281fbd8261c6ab29faa7fd9a87a12e8c0eed485b705236c65999109"}, + {file = "fonttools-4.53.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bd5bc124fae781a4422f61b98d1d7faa47985f663a64770b78f13d2c072410c2"}, + {file = "fonttools-4.53.0-cp312-cp312-win32.whl", hash = "sha256:a239afa1126b6a619130909c8404070e2b473dd2b7fc4aacacd2e763f8597fea"}, + {file = "fonttools-4.53.0-cp312-cp312-win_amd64.whl", hash = "sha256:45b4afb069039f0366a43a5d454bc54eea942bfb66b3fc3e9a2c07ef4d617380"}, + {file = "fonttools-4.53.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:93bc9e5aaa06ff928d751dc6be889ff3e7d2aa393ab873bc7f6396a99f6fbb12"}, + {file = "fonttools-4.53.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2367d47816cc9783a28645bc1dac07f8ffc93e0f015e8c9fc674a5b76a6da6e4"}, + {file = "fonttools-4.53.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:907fa0b662dd8fc1d7c661b90782ce81afb510fc4b7aa6ae7304d6c094b27bce"}, + {file = 
"fonttools-4.53.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e0ad3c6ea4bd6a289d958a1eb922767233f00982cf0fe42b177657c86c80a8f"}, + {file = "fonttools-4.53.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:73121a9b7ff93ada888aaee3985a88495489cc027894458cb1a736660bdfb206"}, + {file = "fonttools-4.53.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:ee595d7ba9bba130b2bec555a40aafa60c26ce68ed0cf509983e0f12d88674fd"}, + {file = "fonttools-4.53.0-cp38-cp38-win32.whl", hash = "sha256:fca66d9ff2ac89b03f5aa17e0b21a97c21f3491c46b583bb131eb32c7bab33af"}, + {file = "fonttools-4.53.0-cp38-cp38-win_amd64.whl", hash = "sha256:31f0e3147375002aae30696dd1dc596636abbd22fca09d2e730ecde0baad1d6b"}, + {file = "fonttools-4.53.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7d6166192dcd925c78a91d599b48960e0a46fe565391c79fe6de481ac44d20ac"}, + {file = "fonttools-4.53.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ef50ec31649fbc3acf6afd261ed89d09eb909b97cc289d80476166df8438524d"}, + {file = "fonttools-4.53.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f193f060391a455920d61684a70017ef5284ccbe6023bb056e15e5ac3de11d1"}, + {file = "fonttools-4.53.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba9f09ff17f947392a855e3455a846f9855f6cf6bec33e9a427d3c1d254c712f"}, + {file = "fonttools-4.53.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0c555e039d268445172b909b1b6bdcba42ada1cf4a60e367d68702e3f87e5f64"}, + {file = "fonttools-4.53.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5a4788036201c908079e89ae3f5399b33bf45b9ea4514913f4dbbe4fac08efe0"}, + {file = "fonttools-4.53.0-cp39-cp39-win32.whl", hash = "sha256:d1a24f51a3305362b94681120c508758a88f207fa0a681c16b5a4172e9e6c7a9"}, + {file = "fonttools-4.53.0-cp39-cp39-win_amd64.whl", hash = "sha256:1e677bfb2b4bd0e5e99e0f7283e65e47a9814b0486cb64a41adf9ef110e078f2"}, + {file = "fonttools-4.53.0-py3-none-any.whl", hash = "sha256:6b4f04b1fbc01a3569d63359f2227c89ab294550de277fd09d8fca6185669fa4"}, + {file = "fonttools-4.53.0.tar.gz", hash = "sha256:c93ed66d32de1559b6fc348838c7572d5c0ac1e4a258e76763a5caddd8944002"}, +] + +[package.extras] +all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "pycairo", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"] +graphite = ["lz4 (>=1.7.4.2)"] +interpolatable = ["munkres", "pycairo", "scipy"] +lxml = ["lxml (>=4.0)"] +pathops = ["skia-pathops (>=0.5.0)"] +plot = ["matplotlib"] +repacker = ["uharfbuzz (>=0.23.0)"] +symfont = ["sympy"] +type1 = ["xattr"] +ufo = ["fs (>=2.2.0,<3)"] +unicode = ["unicodedata2 (>=15.1.0)"] +woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] + [[package]] name = "fsspec" version = "2024.2.0" @@ -1132,6 +1310,27 @@ qtconsole = ["qtconsole"] test = ["pickleshare", "pytest (<8)", "pytest-asyncio (<0.22)", "testpath"] test-extra = ["curio", "ipython[test]", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "trio"] +[[package]] +name = "ipywidgets" +version = "8.1.3" +description = "Jupyter interactive widgets" +optional = false +python-versions = ">=3.7" +files = [ + {file = "ipywidgets-8.1.3-py3-none-any.whl", hash = "sha256:efafd18f7a142248f7cb0ba890a68b96abd4d6e88ddbda483c9130d12667eaf2"}, + {file = "ipywidgets-8.1.3.tar.gz", hash = "sha256:f5f9eeaae082b1823ce9eac2575272952f40d748893972956dc09700a6392d9c"}, +] + 
+[package.dependencies] +comm = ">=0.1.3" +ipython = ">=6.1.0" +jupyterlab-widgets = ">=3.0.11,<3.1.0" +traitlets = ">=4.3.1" +widgetsnbextension = ">=4.0.11,<4.1.0" + +[package.extras] +test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] + [[package]] name = "jax" version = "0.4.25" @@ -1204,6 +1403,23 @@ scipy = ">=1.9" cuda11-pip = ["nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] cuda12-pip = ["nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] +[[package]] +name = "jaxopt" +version = "0.8.3" +description = "Hardware accelerated, batchable and differentiable optimizers in JAX." +optional = false +python-versions = "*" +files = [ + {file = "jaxopt-0.8.3-py3-none-any.whl", hash = "sha256:4be2f82798393682529c9ca5046e5397ac6c8657b8acb6bf275e773b28df15b6"}, + {file = "jaxopt-0.8.3.tar.gz", hash = "sha256:4b06dfa6f915a4f3291699606245af6069371a48dc5c92d4c507840d62990646"}, +] + +[package.dependencies] +jax = ">=0.2.18" +jaxlib = ">=0.1.69" +numpy = ">=1.18.4" +scipy = ">=1.0.0" + [[package]] name = "jaxtyping" version = "0.2.28" @@ -1357,6 +1573,17 @@ files = [ {file = "jupyterlab_pygments-0.3.0.tar.gz", hash = "sha256:721aca4d9029252b11cfa9d185e5b5af4d54772bb8072f9b7036f4170054d35d"}, ] +[[package]] +name = "jupyterlab-widgets" +version = "3.0.11" +description = "Jupyter interactive widgets for JupyterLab" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jupyterlab_widgets-3.0.11-py3-none-any.whl", hash = "sha256:78287fd86d20744ace330a61625024cf5521e1c012a352ddc0a3cdc2348becd0"}, + {file = "jupyterlab_widgets-3.0.11.tar.gz", hash = "sha256:dd5ac679593c969af29c9bed054c24f26842baa51352114736756bc035deee27"}, +] + [[package]] name = "jupytext" version = "1.16.1" @@ -1386,6 +1613,119 @@ test-functional = ["jupytext[test]"] test-integration = ["ipykernel", "jupyter-server (!=2.11)", "jupytext[test-functional]", "nbconvert"] test-ui = ["calysto-bash"] +[[package]] +name = "kiwisolver" +version = "1.4.5" +description = "A fast implementation of the Cassowary constraint solver" +optional = false +python-versions = ">=3.7" +files = [ + {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"}, + {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"}, + {file = "kiwisolver-1.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4"}, + {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1"}, + {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff"}, + {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a"}, + {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa"}, + {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c"}, + {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b"}, + {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770"}, + {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0"}, + {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525"}, + {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b"}, + {file = "kiwisolver-1.4.5-cp310-cp310-win32.whl", hash = "sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238"}, + {file = "kiwisolver-1.4.5-cp310-cp310-win_amd64.whl", hash = "sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276"}, + {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5"}, + {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90"}, + {file = "kiwisolver-1.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797"}, + {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9"}, + {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437"}, + {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9"}, + {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da"}, + {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e"}, + {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8"}, + {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d"}, + {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0"}, + {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f"}, + {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f"}, + {file = "kiwisolver-1.4.5-cp311-cp311-win32.whl", hash = "sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac"}, + {file = "kiwisolver-1.4.5-cp311-cp311-win_amd64.whl", 
hash = "sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355"}, + {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a"}, + {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192"}, + {file = "kiwisolver-1.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a"}, + {file = "kiwisolver-1.4.5-cp312-cp312-win32.whl", hash = "sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20"}, + {file = "kiwisolver-1.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = 
"sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-win32.whl", hash = "sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3"}, + {file = "kiwisolver-1.4.5-cp37-cp37m-win_amd64.whl", hash = "sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a"}, + {file = "kiwisolver-1.4.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71"}, + {file = "kiwisolver-1.4.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93"}, + {file = "kiwisolver-1.4.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29"}, + {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712"}, + {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6"}, + {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb"}, + {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18"}, + {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333"}, + {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da"}, + {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b"}, + {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c"}, + {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc"}, + {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250"}, + {file = "kiwisolver-1.4.5-cp38-cp38-win32.whl", hash = "sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e"}, + {file = "kiwisolver-1.4.5-cp38-cp38-win_amd64.whl", hash = "sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced"}, + {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d"}, + {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9"}, + {file = 
"kiwisolver-1.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046"}, + {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0"}, + {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff"}, + {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54"}, + {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958"}, + {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3"}, + {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf"}, + {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901"}, + {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9"}, + {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342"}, + {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77"}, + {file = "kiwisolver-1.4.5-cp39-cp39-win32.whl", hash = "sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f"}, + {file = "kiwisolver-1.4.5-cp39-cp39-win_amd64.whl", hash = "sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635"}, + {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920"}, + {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390"}, + {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d"}, + {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523"}, + {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4"}, + {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892"}, + {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544"}, + {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126"}, + {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd"}, + {file = 
"kiwisolver-1.4.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929"}, + {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09"}, + {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7"}, + {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad"}, + {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea"}, + {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee"}, + {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, +] + [[package]] name = "markdown" version = "3.5.2" @@ -1425,6 +1765,22 @@ profiling = ["gprof2dot"] rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] +[[package]] +name = "markdown-katex" +version = "202112.1034" +description = "katex extension for Python Markdown" +optional = false +python-versions = ">=2.7" +files = [ + {file = "markdown-katex-202112.1034.tar.gz", hash = "sha256:27892f4cdd6763816f00e4187d0475500697c090aba16630ec4803a6564bf810"}, + {file = "markdown_katex-202112.1034-py2.py3-none-any.whl", hash = "sha256:9ccc5b4b37db7592cc3ea113d763fafe9ffd1b1587e2c217d6145e44a10b4f6d"}, +] + +[package.dependencies] +Markdown = {version = ">=3.0", markers = "python_version >= \"3.6\""} +pathlib2 = "*" +setuptools = "*" + [[package]] name = "markupsafe" version = "2.1.5" @@ -1494,6 +1850,58 @@ files = [ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, ] +[[package]] +name = "matplotlib" +version = "3.9.0" +description = "Python plotting package" +optional = false +python-versions = ">=3.9" +files = [ + {file = "matplotlib-3.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56"}, + {file = "matplotlib-3.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b"}, + {file = "matplotlib-3.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241"}, + {file = "matplotlib-3.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d"}, + {file = "matplotlib-3.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e7f03e5cbbfacdd48c8ea394d365d91ee8f3cae7e6ec611409927b5ed997ee4"}, + {file = "matplotlib-3.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:13beb4840317d45ffd4183a778685e215939be7b08616f431c7795276e067463"}, + {file = "matplotlib-3.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38"}, + {file = "matplotlib-3.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:9a2fa6d899e17ddca6d6526cf6e7ba677738bf2a6a9590d702c277204a7c6152"}, + {file = "matplotlib-3.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:550cdda3adbd596078cca7d13ed50b77879104e2e46392dcd7c75259d8f00e85"}, + {file = "matplotlib-3.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76cce0f31b351e3551d1f3779420cf8f6ec0d4a8cf9c0237a3b549fd28eb4abb"}, + {file = "matplotlib-3.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c53aeb514ccbbcbab55a27f912d79ea30ab21ee0531ee2c09f13800efb272674"}, + {file = "matplotlib-3.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5be985db2596d761cdf0c2eaf52396f26e6a64ab46bd8cd810c48972349d1be"}, + {file = "matplotlib-3.9.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382"}, + {file = "matplotlib-3.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84"}, + {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5"}, + {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db"}, + {file = "matplotlib-3.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7"}, + {file = "matplotlib-3.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf"}, + {file = "matplotlib-3.9.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:616fabf4981a3b3c5a15cd95eba359c8489c4e20e03717aea42866d8d0465956"}, + {file = "matplotlib-3.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd53c79fd02f1c1808d2cfc87dd3cf4dbc63c5244a58ee7944497107469c8d8a"}, + {file = "matplotlib-3.9.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321"}, + {file = "matplotlib-3.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81c40af649d19c85f8073e25e5806926986806fa6d54be506fbf02aef47d5a89"}, + {file = "matplotlib-3.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52146fc3bd7813cc784562cb93a15788be0b2875c4655e2cc6ea646bfa30344b"}, + {file = "matplotlib-3.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:0fc51eaa5262553868461c083d9adadb11a6017315f3a757fc45ec6ec5f02888"}, + {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0"}, + {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03"}, + {file = "matplotlib-3.9.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd"}, + {file = "matplotlib-3.9.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e"}, + {file = "matplotlib-3.9.0.tar.gz", hash = "sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a"}, +] + +[package.dependencies] +contourpy = ">=1.0.1" +cycler = ">=0.10" +fonttools = ">=4.22.0" +kiwisolver = ">=1.3.1" +numpy = ">=1.23" +packaging = ">=20.0" +pillow = ">=8" +pyparsing = ">=2.3.1" +python-dateutil = ">=2.7" + +[package.extras] 
+dev = ["meson-python (>=0.13.1)", "numpy (>=1.25)", "pybind11 (>=2.6)", "setuptools (>=64)", "setuptools_scm (>=7)"] + [[package]] name = "matplotlib-inline" version = "0.1.6" @@ -1791,7 +2199,7 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.21.2", markers = "python_version >= \"3.10\""}, + {version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.3", markers = "python_version >= \"3.11\""}, ] @@ -1860,7 +2268,6 @@ files = [ {file = "msgpack-1.0.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5fbb160554e319f7b22ecf530a80a3ff496d38e8e07ae763b9e82fadfe96f273"}, {file = "msgpack-1.0.8-cp39-cp39-win32.whl", hash = "sha256:f9af38a89b6a5c04b7d18c492c8ccf2aee7048aff1ce8437c4683bb5a1df893d"}, {file = "msgpack-1.0.8-cp39-cp39-win_amd64.whl", hash = "sha256:ed59dd52075f8fc91da6053b12e8c89e37aa043f8986efd89e61fae69dc1b011"}, - {file = "msgpack-1.0.8-py3-none-any.whl", hash = "sha256:24f727df1e20b9876fa6e95f840a2a2651e34c0ad147676356f4bf5fbb0206ca"}, {file = "msgpack-1.0.8.tar.gz", hash = "sha256:95c02b0e27e706e48d0e5426d1710ca78e0f0628d6e89d5b5a5b91a5f12274f3"}, ] @@ -1955,6 +2362,24 @@ files = [ {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, ] +[[package]] +name = "networkx" +version = "3.3" +description = "Python package for creating and manipulating graphs and networks" +optional = false +python-versions = ">=3.10" +files = [ + {file = "networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2"}, + {file = "networkx-3.3.tar.gz", hash = "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9"}, +] + +[package.extras] +default = ["matplotlib (>=3.6)", "numpy (>=1.23)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] +developer = ["changelist (==0.5)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] +doc = ["myst-nb (>=1.0)", "numpydoc (>=1.7)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=2.0)", "pygraphviz (>=1.12)", "sympy (>=1.10)"] +test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] + [[package]] name = "nodeenv" version = "1.8.0" @@ -2164,6 +2589,53 @@ files = [ {file = "paginate-0.5.6.tar.gz", hash = "sha256:5e6007b6a9398177a7e1648d04fdd9f8c9766a1a945bceac82f1929e8c78af2d"}, ] +[[package]] +name = "pandas" +version = "1.5.3" +description = "Powerful data structures for data analysis, time series, and statistics" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pandas-1.5.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3749077d86e3a2f0ed51367f30bf5b82e131cc0f14260c4d3e499186fccc4406"}, + {file = "pandas-1.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:972d8a45395f2a2d26733eb8d0f629b2f90bebe8e8eddbb8829b180c09639572"}, + {file = "pandas-1.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:50869a35cbb0f2e0cd5ec04b191e7b12ed688874bd05dd777c19b28cbea90996"}, + {file = "pandas-1.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3ac844a0fe00bfaeb2c9b51ab1424e5c8744f89860b138434a363b1f620f354"}, + {file = "pandas-1.5.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a0a56cef15fd1586726dace5616db75ebcfec9179a3a55e78f72c5639fa2a23"}, + {file = "pandas-1.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:478ff646ca42b20376e4ed3fa2e8d7341e8a63105586efe54fa2508ee087f328"}, + {file = 
"pandas-1.5.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6973549c01ca91ec96199e940495219c887ea815b2083722821f1d7abfa2b4dc"}, + {file = "pandas-1.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c39a8da13cede5adcd3be1182883aea1c925476f4e84b2807a46e2775306305d"}, + {file = "pandas-1.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f76d097d12c82a535fda9dfe5e8dd4127952b45fea9b0276cb30cca5ea313fbc"}, + {file = "pandas-1.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e474390e60ed609cec869b0da796ad94f420bb057d86784191eefc62b65819ae"}, + {file = "pandas-1.5.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f2b952406a1588ad4cad5b3f55f520e82e902388a6d5a4a91baa8d38d23c7f6"}, + {file = "pandas-1.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:bc4c368f42b551bf72fac35c5128963a171b40dce866fb066540eeaf46faa003"}, + {file = "pandas-1.5.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:14e45300521902689a81f3f41386dc86f19b8ba8dd5ac5a3c7010ef8d2932813"}, + {file = "pandas-1.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9842b6f4b8479e41968eced654487258ed81df7d1c9b7b870ceea24ed9459b31"}, + {file = "pandas-1.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:26d9c71772c7afb9d5046e6e9cf42d83dd147b5cf5bcb9d97252077118543792"}, + {file = "pandas-1.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fbcb19d6fceb9e946b3e23258757c7b225ba450990d9ed63ccceeb8cae609f7"}, + {file = "pandas-1.5.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:565fa34a5434d38e9d250af3c12ff931abaf88050551d9fbcdfafca50d62babf"}, + {file = "pandas-1.5.3-cp38-cp38-win32.whl", hash = "sha256:87bd9c03da1ac870a6d2c8902a0e1fd4267ca00f13bc494c9e5a9020920e1d51"}, + {file = "pandas-1.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:41179ce559943d83a9b4bbacb736b04c928b095b5f25dd2b7389eda08f46f373"}, + {file = "pandas-1.5.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c74a62747864ed568f5a82a49a23a8d7fe171d0c69038b38cedf0976831296fa"}, + {file = "pandas-1.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4c00e0b0597c8e4f59e8d461f797e5d70b4d025880516a8261b2817c47759ee"}, + {file = "pandas-1.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a50d9a4336a9621cab7b8eb3fb11adb82de58f9b91d84c2cd526576b881a0c5a"}, + {file = "pandas-1.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd05f7783b3274aa206a1af06f0ceed3f9b412cf665b7247eacd83be41cf7bf0"}, + {file = "pandas-1.5.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f69c4029613de47816b1bb30ff5ac778686688751a5e9c99ad8c7031f6508e5"}, + {file = "pandas-1.5.3-cp39-cp39-win32.whl", hash = "sha256:7cec0bee9f294e5de5bbfc14d0573f65526071029d036b753ee6507d2a21480a"}, + {file = "pandas-1.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:dfd681c5dc216037e0b0a2c821f5ed99ba9f03ebcf119c7dac0e9a7b960b9ec9"}, + {file = "pandas-1.5.3.tar.gz", hash = "sha256:74a3fd7e5a7ec052f183273dc7b0acd3a863edf7520f5d3a1765c04ffdb3b0b1"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, +] +python-dateutil = ">=2.8.1" +pytz = ">=2020.1" + +[package.extras] +test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] + [[package]] name = "pandocfilters" version = "1.5.1" @@ -2190,6 +2662,20 @@ files = [ qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] testing = 
["docopt", "pytest (<6.0.0)"] +[[package]] +name = "pathlib2" +version = "2.3.7.post1" +description = "Object-oriented filesystem paths" +optional = false +python-versions = "*" +files = [ + {file = "pathlib2-2.3.7.post1-py2.py3-none-any.whl", hash = "sha256:5266a0fd000452f1b3467d782f079a4343c63aaa119221fbdc4e39577489ca5b"}, + {file = "pathlib2-2.3.7.post1.tar.gz", hash = "sha256:9fe0edad898b83c0c3e199c842b27ed216645d2e177757b2dd67384d4113c641"}, +] + +[package.dependencies] +six = "*" + [[package]] name = "pathspec" version = "0.12.1" @@ -2215,6 +2701,92 @@ files = [ [package.dependencies] ptyprocess = ">=0.5" +[[package]] +name = "pillow" +version = "10.3.0" +description = "Python Imaging Library (Fork)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"}, + {file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"}, + {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf"}, + {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599"}, + {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475"}, + {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf"}, + {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3"}, + {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5"}, + {file = "pillow-10.3.0-cp310-cp310-win32.whl", hash = "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2"}, + {file = "pillow-10.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f"}, + {file = "pillow-10.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b"}, + {file = "pillow-10.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795"}, + {file = "pillow-10.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57"}, + {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27"}, + {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994"}, + {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451"}, + {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd"}, + {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad"}, + {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c"}, + {file = "pillow-10.3.0-cp311-cp311-win32.whl", hash = "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09"}, + {file = "pillow-10.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d"}, + {file = "pillow-10.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f"}, + {file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"}, + {file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"}, + {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"}, + {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"}, + {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"}, + {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"}, + {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"}, + {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"}, + {file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"}, + {file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"}, + {file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"}, + {file = "pillow-10.3.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b"}, + {file = "pillow-10.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2"}, + {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa"}, + {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383"}, + {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d"}, + {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd"}, + {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d"}, + {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3"}, + {file = "pillow-10.3.0-cp38-cp38-win32.whl", hash = "sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b"}, + {file = "pillow-10.3.0-cp38-cp38-win_amd64.whl", hash = 
"sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999"}, + {file = "pillow-10.3.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936"}, + {file = "pillow-10.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002"}, + {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60"}, + {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375"}, + {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57"}, + {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8"}, + {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9"}, + {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb"}, + {file = "pillow-10.3.0-cp39-cp39-win32.whl", hash = "sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572"}, + {file = "pillow-10.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb"}, + {file = "pillow-10.3.0-cp39-cp39-win_arm64.whl", hash = "sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"}, + {file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = 
"sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"}, + {file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"}, + {file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] +fpx = ["olefile"] +mic = ["olefile"] +tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] +typing = ["typing-extensions"] +xmp = ["defusedxml"] + [[package]] name = "platformdirs" version = "4.2.0" @@ -2416,6 +2988,20 @@ files = [ {file = "Pympler-1.0.1.tar.gz", hash = "sha256:993f1a3599ca3f4fcd7160c7545ad06310c9e12f70174ae7ae8d4e25f6c5d3fa"}, ] +[[package]] +name = "pyparsing" +version = "3.1.2" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = false +python-versions = ">=3.6.8" +files = [ + {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, + {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pyproject-hooks" version = "1.0.0" @@ -2534,6 +3120,17 @@ files = [ jax = ">=0.4.7" typing-extensions = "*" +[[package]] +name = "pytz" +version = "2024.1" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +files = [ + {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, + {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, +] + [[package]] name = "pywin32" version = "306" @@ -2569,6 +3166,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2576,8 +3174,16 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", 
hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2594,6 +3200,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2601,6 +3208,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -3057,6 +3665,27 @@ dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyl doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +[[package]] +name = "seaborn" +version = "0.12.2" +description = "Statistical data visualization" +optional = false +python-versions = ">=3.7" +files = [ + {file = "seaborn-0.12.2-py3-none-any.whl", hash = "sha256:ebf15355a4dba46037dfd65b7350f014ceb1f13c05e814eda2c9f5fd731afc08"}, + {file = "seaborn-0.12.2.tar.gz", hash = "sha256:374645f36509d0dcab895cba5b47daf0586f77bfe3b36c97c607db7da5be0139"}, +] + +[package.dependencies] +matplotlib = ">=3.1,<3.6.1 || >3.6.1" +numpy = ">=1.17,<1.24.0 || >1.24.0" +pandas = ">=0.25" + +[package.extras] +dev = ["flake8", "flit", "mypy", "pandas-stubs", "pre-commit", "pytest", "pytest-cov", "pytest-xdist"] +docs = ["ipykernel", "nbconvert", "numpydoc", "pydata_sphinx_theme (==0.10.0rc2)", "pyyaml", "sphinx-copybutton", "sphinx-design", "sphinx-issues"] +stats = ["scipy (>=1.3)", "statsmodels (>=0.10)"] + [[package]] name = "setuptools" version = "69.1.1" @@ -3391,6 +4020,25 @@ files = [ [package.extras] watchmedo = ["PyYAML (>=3.10)"] +[[package]] +name = "watermark" +version = "2.4.3" +description = "IPython magic function to print date/time stamps and various system information." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "watermark-2.4.3-py2.py3-none-any.whl", hash = "sha256:39be67f043d7fa0351537fa9b746bbf03ad1bb1ce3d3d84ec96eca954a5e1579"}, + {file = "watermark-2.4.3.tar.gz", hash = "sha256:43d0f7aafb5285af685adce08879f22b2e97be45e786bb93ea4c5e9478dd88e2"}, +] + +[package.dependencies] +importlib-metadata = ">=1.4" +ipython = ">=6.0" +setuptools = "*" + +[package.extras] +gpu = ["py3nvml (>=0.2)"] + [[package]] name = "wcwidth" version = "0.2.13" @@ -3413,6 +4061,17 @@ files = [ {file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"}, ] +[[package]] +name = "widgetsnbextension" +version = "4.0.11" +description = "Jupyter interactive widgets for Jupyter Notebook" +optional = false +python-versions = ">=3.7" +files = [ + {file = "widgetsnbextension-4.0.11-py3-none-any.whl", hash = "sha256:55d4d6949d100e0d08b94948a42efc3ed6dfdc0e9468b2c4b128c9a2ce3a7a36"}, + {file = "widgetsnbextension-4.0.11.tar.gz", hash = "sha256:8b22a8f1910bfd188e596fe7fc05dcbd87e810c8a4ba010bdb3da86637398474"}, +] + [[package]] name = "xdoctest" version = "1.1.3" @@ -3454,4 +4113,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "ae66b9965fd2c0f2e7a82dd7b7bd653bb178d8e837eda500b06787abd424f021" +content-hash = "9304597b78d4fe7ce999c4d52c0fb5ad3e9a09e120f2eacb6e2e7980954d8e50" diff --git a/pyproject.toml b/pyproject.toml index 5a66d34a4..27004f287 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gpjax" -version = "0.8.0" +version = "0.9.0" description = "Gaussian processes in JAX." authors = [ "Thomas Pinder ", @@ -19,13 +19,14 @@ packages = [{ include = "gpjax" }] python = ">=3.10,<3.12" jax = ">=0.4.10" jaxlib = ">=0.4.10" -flax = {git = "https://github.com/google/flax.git"} +flax = { git = "https://github.com/google/flax.git" } optax = "^0.2.1" jaxtyping = "^0.2.10" tqdm = "^4.66.2" tensorflow-probability = "^0.20.0" beartype = "^0.16.1" cola-ml = "0.0.5" +jaxopt = "^0.8.3" [tool.poetry.group.dev.dependencies] ruff = "^0.3.0" @@ -50,18 +51,17 @@ mkdocs-jupyter = "^0.24.3" mkdocs-gen-files = "^0.5.0" mkdocs-literate-nav = "^0.6.0" mkdocs-git-authors-plugin = "^0.7.0" -# markdown-katex = "^202112.1034" -# pymdown-extensions = "^9.11" -# matplotlib = "^3.7.1" -# seaborn = "^0.12.2" -# networkx = "^3.0" -# jupytext = "^1.14.5" -# ipython = "^8.11.0" -# ipykernel = "^6.22.0" -# watermark = "^2.3.1" -# blackjax = "^0.9.6" -# ipywidgets = "^8.0.5" -# pandas = "^1.5.3" +markdown-katex = "^202112.1034" +matplotlib = "^3.7.1" +seaborn = "^0.12.2" +networkx = "^3.0" +jupytext = "^1.14.5" +ipython = "^8.11.0" +ipykernel = "^6.22.0" +watermark = "^2.3.1" +blackjax = "^0.9.6" +ipywidgets = "^8.0.5" +pandas = "^1.5.3" pymdown-extensions = "^10.7.1" nbconvert = "^7.16.2" diff --git a/tests/test_fit.py b/tests/test_fit.py index a2dfa2e25..9c5a98742 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -45,7 +45,7 @@ elbo, ) from gpjax.parameters import ( - Parameter, + PositiveReal, Static, ) from gpjax.typing import Array @@ -62,7 +62,7 @@ def test_fit_simple() -> None: class LinearModel(nnx.Module): def __init__(self, weight: float, bias: float): - self.weight = Parameter(weight) + self.weight = PositiveReal(weight) self.bias = Static(bias) def __call__(self, x): @@ -107,7 +107,7 @@ def test_fit_scipy_simple(): # Define linear model: class LinearModel(nnx.Module): 
def __init__(self, weight: float, bias: float): - self.weight = Parameter(weight) + self.weight = PositiveReal(weight) self.bias = Static(bias) def __call__(self, x): diff --git a/tests/test_kernels/test_nonstationary.py b/tests/test_kernels/test_nonstationary.py index da4d2d09f..7d968e5ca 100644 --- a/tests/test_kernels/test_nonstationary.py +++ b/tests/test_kernels/test_nonstationary.py @@ -31,7 +31,6 @@ Polynomial, ) from gpjax.parameters import ( - Parameter, PositiveReal, Static, ) @@ -94,8 +93,8 @@ def test_init_override_paramtype(kernel_request): continue new_params[param] = Static(value) - k = kernel(**new_params, variance=Parameter(variance)) - assert isinstance(k.variance, Parameter) + k = kernel(**new_params, variance=PositiveReal(variance)) + assert isinstance(k.variance, PositiveReal) for param in params.keys(): if param in ("degree", "order"): diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index c970753dc..ce490e65c 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -35,7 +35,6 @@ ) from gpjax.kernels.stationary.base import StationaryKernel from gpjax.parameters import ( - Parameter, PositiveReal, Static, ) @@ -104,12 +103,12 @@ def test_init_override_paramtype(kernel_request): for param, value in params.items(): new_params[param] = Static(value) - kwargs = {**new_params, "variance": Parameter(variance)} + kwargs = {**new_params, "variance": PositiveReal(variance)} if kernel != White: - kwargs["lengthscale"] = Parameter(lengthscale) + kwargs["lengthscale"] = PositiveReal(lengthscale) k = kernel(**kwargs) - assert isinstance(k.variance, Parameter) + assert isinstance(k.variance, PositiveReal) for param in params.keys(): assert isinstance(getattr(k, param), Static) diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index 94b1931f7..b0e492199 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -13,10 +13,9 @@ # limitations under the License. 
# ============================================================================== -from itertools import product from typing import ( Callable, - List, + Tuple, ) from jax import config @@ -31,7 +30,6 @@ import tensorflow_probability.substrates.jax.distributions as tfd from gpjax.likelihoods import ( - AbstractLikelihood, Bernoulli, Gaussian, Poisson, @@ -43,168 +41,73 @@ _initialise_key = jr.PRNGKey(123) -class BaseTestLikelihood: - """A base class that contains all tests applied on likelihoods.""" - - likelihood: AbstractLikelihood - static_fields: List[str] = ["num_datapoints"] - - def pytest_generate_tests(self, metafunc): - """This is called automatically by pytest.""" - - # function for pretty test name - def id_func(x): - return "-".join([f"{k}={v}" for k, v in x.items()]) - - # get arguments for the test function - funcarglist = metafunc.cls.params.get(metafunc.function.__name__, None) - - if funcarglist is None: - return - else: - # equivalent of pytest.mark.parametrize applied on the metafunction - metafunc.parametrize("fields", funcarglist, ids=id_func) - - @pytest.mark.parametrize("n", [1, 2, 10], ids=lambda x: f"n={x}") - def test_initialisation(self, fields: dict, n: int) -> None: - # Check that likelihood is a dataclass - - # Input fields as JAX arrays - fields = {k: jnp.array([v]) for k, v in fields.items()} - - # Initialise - likelihood: AbstractLikelihood = self.likelihood(num_datapoints=n, **fields) - - # Check properties - for field, value in fields.items(): - assert getattr(likelihood, field).value == value - - @pytest.mark.parametrize("n", [1, 2, 10], ids=lambda x: f"n={x}") - def test_link_functions(self, n: int): - # Initialize likelihood with defaults - likelihood: AbstractLikelihood = self.likelihood(num_datapoints=n) - - # Create input values - x = jnp.linspace(-3.0, 3.0).reshape(-1, 1) - - # Test likelihood link function. - assert isinstance(likelihood.link_function, Callable) - assert isinstance(likelihood.link_function(x), tfd.Distribution) - - @pytest.mark.parametrize("n", [1, 2, 10], ids=lambda x: f"n={x}") - def test__call__(self, fields: dict, n: int): - # Input fields as JAX arrays - fields = {k: jnp.array([v]) for k, v in fields.items()} - - # Initialise - likelihood: AbstractLikelihood = self.likelihood(num_datapoints=n, **fields) - - # Construct latent function distribution. - k1, k2 = jr.split(_initialise_key) - latent_mean = jr.uniform(k1, shape=(n,)) - latent_sqrt = jr.uniform(k2, shape=(n, n)) - latent_cov = jnp.matmul(latent_sqrt, latent_sqrt.T) - latent_dist = tfd.MultivariateNormalFullCovariance(latent_mean, latent_cov) - - # Perform checks specific to the given likelihood - self._test_call_check(likelihood, latent_mean, latent_cov, latent_dist) - - @staticmethod - def _test_call_check(likelihood, latent_mean, latent_cov, latent_dist): - """Specific to each likelihood.""" - raise NotImplementedError - - -def prod(inp): - return [ - dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values()) - ] - - -class TestGaussian(BaseTestLikelihood): - likelihood = Gaussian - fields = prod({"obs_stddev": [0.1, 0.5, 1.0]}) - params = {"test_initialisation": fields, "test_call": fields} - static_fields = ["num_datapoints"] - - @staticmethod - def _test_call_check(likelihood: Gaussian, latent_mean, latent_cov, latent_dist): - # Test call method. - pred_dist = likelihood(latent_dist) - - # Check that the distribution is a MultivariateNormalFullCovariance. 
- assert isinstance(pred_dist, tfd.MultivariateNormalFullCovariance) - - # Check predictive mean and variance. - assert (pred_dist.mean() == latent_mean).all() - noise_matrix = ( - jnp.eye(likelihood.num_datapoints) * likelihood.obs_stddev.value**2 - ) - assert np.allclose( - pred_dist.scale_tril, jnp.linalg.cholesky(latent_cov + noise_matrix) - ) - - -class TestBernoulli(BaseTestLikelihood): - likelihood = Bernoulli - fields = prod({}) - params = {"test_initialisation": fields, "test_call": fields} - static_fields = ["num_datapoints"] - - @staticmethod - def _test_call_check( - likelihood: AbstractLikelihood, latent_mean, latent_cov, latent_dist - ): - # Test call method. - pred_dist = likelihood(latent_dist) - - # Check that the distribution is a Bernoulli. - assert isinstance(pred_dist, tfd.Bernoulli) - - # Check predictive mean and variance. - - p = inv_probit(latent_mean / jnp.sqrt(1.0 + jnp.diagonal(latent_cov))) - assert (pred_dist.mean() == p).all() - assert (pred_dist.variance() == p * (1.0 - p)).all() - - -class TestPoisson(BaseTestLikelihood): - likelihood = Poisson - fields = prod({}) - params = {"test_initialisation": fields, "test_call": fields} - static_fields = ["num_datapoints"] - - @staticmethod - def _test_call_check( - likelihood: AbstractLikelihood, latent_mean, latent_cov, latent_dist - ): - # Test call method. - pred_dist = likelihood(latent_dist) - - # Check that the distribution is a Poisson. - assert isinstance(pred_dist, tfd.Poisson) - - # Check predictive mean and variance. - rate = jnp.exp(latent_mean) - assert (pred_dist.mean() == rate).all() - - -class TestAbstract(BaseTestLikelihood): - class DummyLikelihood(AbstractLikelihood): - def predict(self, dist: tfd.Distribution) -> tfd.Distribution: - return tfd.Normal(0.0, 1.0) - - def link_function(self, f: Float[Array, "N 1"]) -> Float[Array, "N 1"]: - return tfd.MultivariateNormalDiag(loc=f) - - likelihood = DummyLikelihood - fields = prod({}) - params = {"test_initialisation": fields, "test_call": fields} - static_fields = ["num_datapoints"] - - @staticmethod - def _test_call_check( - likelihood: AbstractLikelihood, latent_mean, latent_cov, latent_dist - ): - pred_dist = likelihood(latent_dist) - assert isinstance(pred_dist, tfd.Normal) +def _compute_latent_dist( + n: int, +) -> Tuple[ + tfd.MultivariateNormalFullCovariance, Float[Array, " N"], Float[Array, "N N"] +]: + k1, k2 = jr.split(_initialise_key) + latent_mean = jr.uniform(k1, shape=(n,)) + latent_sqrt = jr.uniform(k2, shape=(n, n)) + latent_cov = jnp.matmul(latent_sqrt, latent_sqrt.T) + latent_dist = tfd.MultivariateNormalFullCovariance(latent_mean, latent_cov) + return latent_dist, latent_mean, latent_cov + + +@pytest.mark.parametrize("n", [1, 2, 10]) +@pytest.mark.parametrize("obs_stddev", [0.1, 0.5, 1.0]) +def test_gaussian_likelihood(n: int, obs_stddev: float): + x = jnp.linspace(-3.0, 3.0).reshape(-1, 1) + likelihood = Gaussian(num_datapoints=n, obs_stddev=obs_stddev) + + assert isinstance(likelihood.link_function, Callable) + assert isinstance(likelihood.link_function(x), tfd.Distribution) + + # Construct latent function distribution. + latent_dist, latent_mean, latent_cov = _compute_latent_dist(n) + pred_dist = likelihood(latent_dist) + assert isinstance(pred_dist, tfd.MultivariateNormalFullCovariance) + + # Check predictive mean and variance. 
+ assert (pred_dist.mean() == latent_mean).all() + noise_matrix = jnp.eye(likelihood.num_datapoints) * likelihood.obs_stddev.value**2 + assert np.allclose( + pred_dist.scale_tril, jnp.linalg.cholesky(latent_cov + noise_matrix) + ) + + +@pytest.mark.parametrize("n", [1, 2, 10]) +def test_bernoulli_likelihood(n: int): + x = jnp.linspace(-3.0, 3.0).reshape(-1, 1) + likelihood = Bernoulli(num_datapoints=n) + + assert isinstance(likelihood.link_function, Callable) + assert isinstance(likelihood.link_function(x), tfd.Distribution) + + # Construct latent function distribution. + latent_dist, latent_mean, latent_cov = _compute_latent_dist(n) + pred_dist = likelihood(latent_dist) + assert isinstance(pred_dist, tfd.Bernoulli) + + # Check predictive mean and variance. + p = inv_probit(latent_mean / jnp.sqrt(1.0 + jnp.diagonal(latent_cov))) + assert (pred_dist.mean() == p).all() + assert (pred_dist.variance() == p * (1.0 - p)).all() + + +@pytest.mark.parametrize("n", [1, 2, 10]) +def test_poisson_likelihood(n: int): + x = jnp.linspace(-3.0, 3.0).reshape(-1, 1) + likelihood = Poisson(num_datapoints=n) + + assert isinstance(likelihood.link_function, Callable) + assert isinstance(likelihood.link_function(x), tfd.Distribution) + + # Construct latent function distribution. + latent_dist, latent_mean, latent_cov = _compute_latent_dist(n) + pred_dist = likelihood(latent_dist) + assert isinstance(pred_dist, tfd.Poisson) + + # Check predictive mean and variance. + rate = jnp.exp(latent_mean) + assert (pred_dist.mean() == rate).all() diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 821901bd8..e62c4ceec 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -1,33 +1,60 @@ from flax.experimental import nnx import jax.numpy as jnp +import pytest from gpjax.parameters import ( DEFAULT_BIJECTION, + LowerTriangular, Parameter, PositiveReal, + Real, + SigmoidBounded, + Static, transform, ) -def test_transform(): +@pytest.mark.parametrize( + "param, value", + [ + (PositiveReal, 1.0), + (Real, 2.0), + (SigmoidBounded, 0.5), + ], +) +def test_transform(param, value): # Create mock parameters and bijectors params = nnx.State( { - "param1": PositiveReal(1.0), - "param2": Parameter(2.0), + "param1": param(value), + "param2": Parameter(2.0, tag="real"), } ) # Test forward transformation t_params = transform(params, DEFAULT_BIJECTION) - t_param1_expected = DEFAULT_BIJECTION[PositiveReal].forward(1.0) + t_param1_expected = DEFAULT_BIJECTION[params["param1"]._tag].forward(value) assert jnp.allclose(t_params["param1"].value, t_param1_expected) assert jnp.allclose(t_params["param2"].value, 2.0) # Test inverse transformation t_params = transform(params, DEFAULT_BIJECTION, inverse=True) - t_param1_expected = DEFAULT_BIJECTION[PositiveReal].inverse( + t_param1_expected = DEFAULT_BIJECTION[params["param1"]._tag].inverse( t_params["param1"].value ) - assert jnp.allclose(t_params["param1"].value, 1.0) + assert jnp.allclose(t_params["param1"].value, value) assert jnp.allclose(t_params["param2"].value, 2.0) + + +@pytest.mark.parametrize( + "param, tag", + [ + (PositiveReal(1.0), "positive"), + (Real(2.0), "real"), + (SigmoidBounded(0.5), "sigmoid"), + (Static(2.0), "static"), + (LowerTriangular(jnp.eye(2)), "lower_triangular"), + ], +) +def test_default_tags(param, tag): + assert param._tag == tag
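
Note (appended by the editor, not part of the diff): the test changes above replace the generic Parameter wrapper with concrete, tagged parameter types (PositiveReal, Real, SigmoidBounded, Static, LowerTriangular), and DEFAULT_BIJECTION is now indexed by each parameter's string tag rather than by its class. The sketch below is inferred solely from tests/test_parameters.py and tests/test_fit.py in this diff; it illustrates the intended usage rather than quoting the library, and the exact constrained/unconstrained direction of the forward map is an assumption.

# Minimal sketch of the tag-based parameter API exercised by the updated tests.
from flax.experimental import nnx
import jax.numpy as jnp

from gpjax.parameters import (
    DEFAULT_BIJECTION,
    PositiveReal,
    Real,
    transform,
)

# Concrete parameter types carry a string tag (e.g. "positive", "real",
# "sigmoid", "static", "lower_triangular") that indexes DEFAULT_BIJECTION.
params = nnx.State(
    {
        "lengthscale": PositiveReal(1.0),  # _tag == "positive"
        "offset": Real(2.0),               # _tag == "real", effectively identity
    }
)

# transform applies each leaf's tagged bijector; inverse=True applies the
# inverse map, mirroring the round-trip assertions in test_transform.
t_params = transform(params, DEFAULT_BIJECTION)
expected = DEFAULT_BIJECTION[params["lengthscale"]._tag].forward(1.0)
assert jnp.allclose(t_params["lengthscale"].value, expected)
assert jnp.allclose(t_params["offset"].value, 2.0)

Fields that should be left untransformed are wrapped in Static instead, which is the pattern the updated LinearModel in tests/test_fit.py and the kernel tests follow when overriding a kernel's default parameter types.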