diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 47ba74a30..0fa71556d 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -29,7 +29,10 @@ ) from gpjax.citation import cite from gpjax.dataset import Dataset -from gpjax.fit import fit +from gpjax.fit import ( + fit, + fit_scipy, +) __license__ = "MIT" __description__ = "Didactic Gaussian processes in JAX" @@ -52,4 +55,5 @@ "fit", "Module", "param_field", + "fit_scipy", ]