diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 1c49f7f..2cba2f8 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -19,7 +19,7 @@ jobs:
# You can test your matrix by printing the current Python version
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ python -m pip install --upgrade pip wheel packaging
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -e .
- name: Test with pytest
diff --git a/README.md b/README.md
index baab30f..af80878 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,10 @@
-
-
+
-[![Main](https://travis-ci.com/AlexImmer/Laplace.svg?token=rpuRxEjQS6cCZi7ptL9y&branch=main)](https://travis-ci.com/AlexImmer/Laplace)
+![pytest](https://github.com/aleximmer/laplace/actions/workflows/pytest.yml/badge.svg)
+![lint](https://github.com/aleximmer/laplace/actions/workflows/lint-ruff.yml/badge.svg)
+![format](https://github.com/aleximmer/laplace/actions/workflows/format-ruff.yml/badge.svg)
+
The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer.
The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations.
@@ -45,6 +47,13 @@ pytest tests/
## Example usage
+> [!IMPORTANT]
+> As a user, you should not expect Laplace to work well out of the box.
+> Experiment with the different options Laplace offers (Hessian factorization,
+> prior-precision tuning method, predictive method, backend, etc.). Papers that
+> use Laplace are a good reference for how to set these options for the
+> application/problem at hand; see also the sketch below.
+
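+A minimal sketch of the knobs to vary (the particular choices here are
+illustrative, not recommendations; `model`, `train_loader`, and `x` are
+assumed to be defined):
+
+```python
+from laplace import Laplace
+
+# Last-layer Laplace with a Kronecker-factored Hessian approximation
+la = Laplace(
+    model,
+    likelihood="classification",
+    subset_of_weights="last_layer",
+    hessian_structure="kron",
+)
+la.fit(train_loader)
+
+# Tune the prior precision post hoc via the marginal likelihood
+la.optimize_prior_precision(method="marglik")
+
+# Predict with the linearized (GLM) predictive and a probit link approximation
+pred = la(x, pred_type="glm", link_approx="probit")
+```
+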
### _Post-hoc_ prior precision tuning of diagonal LA
In the following example, a pre-trained model is loaded,
@@ -283,6 +292,11 @@ trained on a GPU but want to run predictions on CPU. In this case, use
torch.load(..., map_location="cpu")
```
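+
+For example, a hypothetical sketch (the file name `la_state.pt` is
+illustrative; `la` must be constructed with the same configuration it was
+saved with):
+
+```python
+la = Laplace(model, "classification", subset_of_weights="all", hessian_structure="full")
+la.load_state_dict(torch.load("la_state.pt", map_location="cpu"))
+```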
+> [!WARNING]
+> Currently, this library always assumes that the model outputs a tensor of
+> shape `(batch_size, ..., n_classes)`. For image-shaped outputs, this means
+> you need to rearrange from NCHW to NHWC, as in the sketch below.
+
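+A hypothetical sketch (assuming `model` produces image-shaped outputs):
+
+```python
+out_nchw = model(x)                      # (batch, channels, H, W)
+out_nhwc = out_nchw.permute(0, 2, 3, 1)  # (batch, H, W, channels)
+```
+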
## Structure
The laplace package consists of two main components:
diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py
index da956d0..de246b8 100644
--- a/laplace/baselaplace.py
+++ b/laplace/baselaplace.py
@@ -1021,18 +1021,23 @@ def _glm_predictive_distribution(
self,
X: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
joint: bool = False,
- diagonal_output=False,
+ diagonal_output: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
- backend_name = self._backend_cls.__name__.lower()
- if self.enable_backprop and (
- "curvlinops" not in backend_name and "backpack" not in backend_name
- ):
- raise ValueError(
- "Backprop through the GLM predictive is only available for the "
- "Curvlinops and BackPACK backends."
+ if "asdl" in self._backend_cls.__name__.lower():
+            # ASDL doesn't support backprop through its Jacobians,
+            # so fall back to the functorch-based implementation
+ warnings.warn(
+                "The ASDL backend does not support backprop through the "
+                "functional variance, but `self.enable_backprop = True`. "
+                "Falling back to `self.backend.functorch_jacobians`, which "
+                "can be memory-intensive for large models."
)
- Js, f_mu = self.backend.jacobians(X, enable_backprop=self.enable_backprop)
+ Js, f_mu = self.backend.functorch_jacobians(
+ X, enable_backprop=self.enable_backprop
+ )
+ else:
+ Js, f_mu = self.backend.jacobians(X, enable_backprop=self.enable_backprop)
if joint:
f_mu = f_mu.flatten() # (batch*out)
diff --git a/laplace/curvature/curvature.py b/laplace/curvature/curvature.py
index 7b5fe60..c0397f1 100644
--- a/laplace/curvature/curvature.py
+++ b/laplace/curvature/curvature.py
@@ -287,6 +287,8 @@ def diag(
"""
raise NotImplementedError
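+    # By default, the functorch-based Jacobians are simply the base
+    # `jacobians` implementation. Backends that override `jacobians` with a
+    # routine that does not support backprop (e.g. ASDL) keep this alias as
+    # a fallback, since class-level aliasing binds the base implementation.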
+ functorch_jacobians = jacobians
+
class GGNInterface(CurvatureInterface):
"""Generalized Gauss-Newton or Fisher Curvature Interface.
diff --git a/tests/test_baselaplace.py b/tests/test_baselaplace.py
index d7393fb..5015a61 100644
--- a/tests/test_baselaplace.py
+++ b/tests/test_baselaplace.py
@@ -662,7 +662,9 @@ def test_dict_data(laplace, backend, lik, custom_loader, custom_model, request):
@pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace])
-@pytest.mark.parametrize("backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF])
+@pytest.mark.parametrize(
+ "backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF]
+)
def test_backprop_glm(laplace, model, reg_loader, backend):
X, y = reg_loader.dataset.tensors
X.requires_grad = True
@@ -682,7 +684,9 @@ def test_backprop_glm(laplace, model, reg_loader, backend):
@pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace])
-@pytest.mark.parametrize("backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF])
+@pytest.mark.parametrize(
+ "backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF]
+)
def test_backprop_glm_joint(laplace, model, reg_loader, backend):
X, y = reg_loader.dataset.tensors
X.requires_grad = True
@@ -702,7 +706,9 @@ def test_backprop_glm_joint(laplace, model, reg_loader, backend):
@pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace])
-@pytest.mark.parametrize("backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF])
+@pytest.mark.parametrize(
+ "backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF]
+)
def test_backprop_glm_mc(laplace, model, reg_loader, backend):
X, y = reg_loader.dataset.tensors
X.requires_grad = True
@@ -722,7 +728,9 @@ def test_backprop_glm_mc(laplace, model, reg_loader, backend):
@pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace])
-@pytest.mark.parametrize("backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF])
+@pytest.mark.parametrize(
+ "backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF]
+)
def test_backprop_nn(laplace, model, reg_loader, backend):
X, y = reg_loader.dataset.tensors
X.requires_grad = True