Skip to content

Commit

Permalink
Add updated files after running pre-commit checks
Browse files Browse the repository at this point in the history
  • Loading branch information
jakeyeung committed Oct 10, 2023
1 parent 19a58da commit 533347f
Show file tree
Hide file tree
Showing 16 changed files with 46 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,4 @@ package.json
package-lock.json
node_modules/

docs/api
docs/api
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Another way you can contribute to GPJax is through [issue
triaging](https://www.codetriage.com/what). This can include reproducing bug reports,
asking for vital information such as version numbers and reproduction instructions, or
identifying stale issues. If you would like to begin triaging issues, an easy way to get
started is to
started is to
[subscribe to GPJax on CodeTriage](https://www.codetriage.com/jaxgaussianprocesses/gpjax).

As a contributor to GPJax, you are expected to abide by our [code of
Expand Down
2 changes: 1 addition & 1 deletion docs/GOVERNANCE.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,4 @@ maintainers or reach out over
-----

This file was adapted from
[BlackJAX](https://github.com/blackjax-devs/blackjax/blob/main/GOVERNANCE.md).
[BlackJAX](https://github.com/blackjax-devs/blackjax/blob/main/GOVERNANCE.md).
2 changes: 1 addition & 1 deletion gpjax/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import warnings

from beartype.typing import (
Literal,
Optional,
Union,
Literal,
)
import jax.numpy as jnp
from jaxtyping import (
Expand Down
2 changes: 1 addition & 1 deletion gpjax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

from beartype.typing import (
Any,
Generic,
Optional,
Tuple,
Generic,
TypeVar,
Union,
)
Expand Down
11 changes: 8 additions & 3 deletions gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

# from __future__ import annotations
from abc import abstractmethod
from dataclasses import dataclass, field
from dataclasses import (
dataclass,
field,
)
from typing import overload

from beartype.typing import (
Expand All @@ -25,7 +28,6 @@
)
import cola
from cola.ops import Dense

import jax.numpy as jnp
from jax.random import (
PRNGKey,
Expand All @@ -47,7 +49,10 @@
ReshapedDistribution,
ReshapedGaussianDistribution,
)
from gpjax.kernels import RFF, White
from gpjax.kernels import (
RFF,
White,
)
from gpjax.kernels.base import AbstractKernel
from gpjax.likelihoods import (
AbstractLikelihood,
Expand Down
5 changes: 4 additions & 1 deletion gpjax/kernels/computations/constant_diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
)
from jax import vmap
import jax.numpy as jnp
from jaxtyping import Float, Num
from jaxtyping import (
Float,
Num,
)

from gpjax.kernels.computations import AbstractKernelComputation
from gpjax.typing import Array
Expand Down
5 changes: 4 additions & 1 deletion gpjax/kernels/computations/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

import beartype.typing as tp
from jax import vmap
from jaxtyping import Float, Num
from jaxtyping import (
Float,
Num,
)

from gpjax.kernels.computations.base import AbstractKernelComputation
from gpjax.typing import Array
Expand Down
5 changes: 4 additions & 1 deletion gpjax/kernels/computations/diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
LinearOperator,
)
from jax import vmap
from jaxtyping import Float, Num
from jaxtyping import (
Float,
Num,
)

from gpjax.kernels.computations import AbstractKernelComputation
from gpjax.typing import Array
Expand Down
5 changes: 4 additions & 1 deletion gpjax/kernels/stationary/matern32.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

from beartype.typing import Union
import jax.numpy as jnp
from jaxtyping import Float, Num
from jaxtyping import (
Float,
Num,
)
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd

Expand Down
5 changes: 4 additions & 1 deletion gpjax/kernels/stationary/matern52.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

from beartype.typing import Union
import jax.numpy as jnp
from jaxtyping import Float, Num
from jaxtyping import (
Float,
Num,
)
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd

Expand Down
5 changes: 4 additions & 1 deletion gpjax/kernels/stationary/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

from beartype.typing import Union
import jax.numpy as jnp
from jaxtyping import Float, Num
from jaxtyping import (
Float,
Num,
)
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base import param_field
Expand Down
5 changes: 4 additions & 1 deletion gpjax/kernels/stationary/rbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

from beartype.typing import Union
import jax.numpy as jnp
from jaxtyping import Float, Num
from jaxtyping import (
Float,
Num,
)
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd

Expand Down
5 changes: 4 additions & 1 deletion gpjax/likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
param_field,
static_field,
)
from gpjax.distributions import GaussianDistribution, ReshapedDistribution
from gpjax.distributions import (
GaussianDistribution,
ReshapedDistribution,
)
from gpjax.integrators import (
AbstractIntegrator,
AnalyticalGaussianIntegrator,
Expand Down
1 change: 0 additions & 1 deletion gpjax/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
ScalarFloat,
)


tfd = tfp.distributions

import cola
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from jax import config
from jaxtyping import install_import_hook

config.update("jax_enable_x64", True)

# import gpjax within import hook to apply beartype everywhere, before running tests
Expand Down

0 comments on commit 533347f

Please sign in to comment.