from __future__ import annotations
import warnings
from collections.abc import MutableMapping
from importlib.util import find_spec
from math import log, pi, sqrt
from typing import Any, Callable
import numpy as np
import torch
import torchmetrics
import tqdm
from torch import nn
from torch.linalg import LinAlgError
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.utils.data import DataLoader
from laplace.curvature.asdfghjkl import AsdfghjklHessian
from laplace.curvature.asdl import AsdlGGN
from laplace.curvature.backpack import BackPackGGN
from laplace.curvature.curvature import CurvatureInterface
from laplace.curvature.curvlinops import CurvlinopsEF, CurvlinopsGGN
from laplace.utils import SoDSampler
from laplace.utils.enums import (
Likelihood,
LinkApprox,
PredType,
PriorStructure,
TuningMethod,
)
from laplace.utils.matrix import Kron, KronDecomposed
from laplace.utils.metrics import RunningNLLMetric
from laplace.utils.utils import (
fix_prior_prec_structure,
invsqrt_precision,
normal_samples,
validate,
)
__all__ = [
"BaseLaplace",
"ParametricLaplace",
"FunctionalLaplace",
"FullLaplace",
"KronLaplace",
"DiagLaplace",
"LowRankLaplace",
]
class BaseLaplace:
"""Baseclass for all Laplace approximations in this library.
Parameters
----------
model : torch.nn.Module
likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
determines the log likelihood Hessian approximation.
In the case of 'reward_modeling', it fits Laplace using the classification likelihood,
then does prediction as in regression likelihood. The model needs to be defined accordingly:
The forward pass during training takes `x.shape == (batch_size, 2, dim)` with
`y.shape = (batch_size,)`. Meanwhile, during evaluation `x.shape == (batch_size, dim)`.
Note that 'reward_modeling' only supports `KronLaplace` and `DiagLaplace`.
sigma_noise : torch.Tensor or float, default=1
observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1
prior precision of a Gaussian prior (= weight decay);
can be scalar, per-layer, or diagonal in the most general case
prior_mean : torch.Tensor or float, default=0
prior mean of a Gaussian prior, useful for continual learning
temperature : float, default=1
temperature of the likelihood; lower temperature leads to more
concentrated posterior and vice versa.
enable_backprop: bool, default=False
whether to enable backprop to the input `x` through the Laplace predictive.
Useful for e.g. Bayesian optimization.
dict_key_x: str, default='input_ids'
The dictionary key under which the input tensor `x` is stored. Only has effect
when the model takes a `MutableMapping` as the input. Useful for Huggingface
LLM models.
dict_key_y: str, default='labels'
The dictionary key under which the target tensor `y` is stored. Only has effect
when the model takes a `MutableMapping` as the input. Useful for Huggingface
LLM models.
backend : subclasses of `laplace.curvature.CurvatureInterface`
backend for access to curvature/Hessian approximations. Defaults to CurvlinopsGGN if None.
backend_kwargs : dict, default=None
arguments passed to the backend on initialization, for example to
set the number of MC samples for stochastic approximations.
asdl_fisher_kwargs : dict, default=None
arguments passed to the ASDL backend specifically on initialization.
"""
def __init__(
self,
model: nn.Module,
likelihood: Likelihood | str,
sigma_noise: float | torch.Tensor = 1.0,
prior_precision: float | torch.Tensor = 1.0,
prior_mean: float | torch.Tensor = 0.0,
temperature: float = 1.0,
enable_backprop: bool = False,
dict_key_x: str = "input_ids",
dict_key_y: str = "labels",
backend: type[CurvatureInterface] | None = None,
backend_kwargs: dict[str, Any] | None = None,
asdl_fisher_kwargs: dict[str, Any] | None = None,
) -> None:
if likelihood not in [lik.value for lik in Likelihood]:
raise ValueError(f"Invalid likelihood type {likelihood}")
self.model: nn.Module = model
self.likelihood: Likelihood | str = likelihood
# Only do Laplace on params that require grad
self.params: list[torch.Tensor] = []
self.is_subset_params: bool = False
for p in model.parameters():
if p.requires_grad:
self.params.append(p)
else:
self.is_subset_params = True
self.n_params: int = sum(p.numel() for p in self.params)
self.n_layers: int = len(self.params)
self.prior_precision: float | torch.Tensor = prior_precision
self.prior_mean: float | torch.Tensor = prior_mean
if sigma_noise != 1 and likelihood != Likelihood.REGRESSION:
raise ValueError("Sigma noise != 1 only available for regression.")
self.sigma_noise: float | torch.Tensor = sigma_noise
self.temperature: float = temperature
self.enable_backprop: bool = enable_backprop
# For models with dict-like inputs (e.g. Huggingface LLMs)
self.dict_key_x = dict_key_x
self.dict_key_y = dict_key_y
if backend is None:
backend = CurvlinopsGGN
else:
if self.is_subset_params and (
"backpack" in backend.__name__.lower()
or "asdfghjkl" in backend.__name__.lower()
):
raise ValueError(
"If some grad are switched off, the BackPACK and Asdfghjkl backends"
" are not supported."
)
self._backend: CurvatureInterface | None = None
self._backend_cls: type[CurvatureInterface] = backend
self._backend_kwargs: dict[str, Any] = (
dict() if backend_kwargs is None else backend_kwargs
)
self._asdl_fisher_kwargs: dict[str, Any] = (
dict() if asdl_fisher_kwargs is None else asdl_fisher_kwargs
)
# log likelihood = g(loss)
self.loss: float = 0.0
self.n_outputs: int = 0
self.n_data: int = 0
# Declare attributes
self._prior_mean: torch.Tensor
self._prior_precision: torch.Tensor
self._sigma_noise: torch.Tensor
self._posterior_scale: torch.Tensor | None
@property
def _device(self) -> torch.device:
return next(self.model.parameters()).device
@property
def backend(self) -> CurvatureInterface:
if self._backend is None:
likelihood = (
"classification"
if self.likelihood == "reward_modeling"
else self.likelihood
)
self._backend = self._backend_cls(
self.model,
likelihood,
dict_key_x=self.dict_key_x,
dict_key_y=self.dict_key_y,
**self._backend_kwargs,
)
return self._backend
def _curv_closure(
self,
X: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
y: torch.Tensor,
N: int,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def fit(self, train_loader: DataLoader) -> None:
raise NotImplementedError
def log_marginal_likelihood(
self,
prior_precision: torch.Tensor | None = None,
sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
@property
def log_likelihood(self) -> torch.Tensor:
"""Compute log likelihood on the training data after `.fit()` has been called.
The log likelihood is computed on-demand based on the loss and, for example,
        the observation noise, which makes it differentiable in the latter for
iterative updates.
Returns
-------
log_likelihood : torch.Tensor
"""
factor = -self._H_factor
if self.likelihood == "regression":
# loss used is just MSE, need to add normalizer for gaussian likelihood
c = (
self.n_data
* self.n_outputs
* torch.log(torch.as_tensor(self.sigma_noise) * sqrt(2 * pi))
)
return factor * self.loss - c
else:
# for classification Xent == log Cat
return factor * self.loss
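    # Worked form of the property above, in the notation of this class (a
    # sketch, not additional functionality): for regression,
    #   log_likelihood = -self.loss / (sigma_noise**2 * temperature)
    #                    - n_data * n_outputs * log(sigma_noise * sqrt(2 * pi)),
    # while for classification the cross-entropy `self.loss` is already a
    # negative log likelihood, so log_likelihood = -self.loss / temperature.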
def __call__(
self,
x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
pred_type: PredType | str,
link_approx: LinkApprox | str,
n_samples: int,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def predictive(
self,
x: torch.Tensor,
pred_type: PredType | str,
link_approx: LinkApprox | str,
n_samples: int,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self(x, pred_type, link_approx, n_samples)
def _check_jacobians(self, Js: torch.Tensor) -> None:
if not isinstance(Js, torch.Tensor):
raise ValueError("Jacobians have to be torch.Tensor.")
if not Js.device == self._device:
raise ValueError("Jacobians need to be on the same device as Laplace.")
m, k, p = Js.size()
if p != self.n_params:
raise ValueError("Invalid Jacobians shape for Laplace posterior approx.")
@property
def prior_precision_diag(self) -> torch.Tensor:
"""Obtain the diagonal prior precision \\(p_0\\) constructed from either
a scalar, layer-wise, or diagonal prior precision.
Returns
-------
prior_precision_diag : torch.Tensor
"""
prior_prec: torch.Tensor = (
self.prior_precision
if isinstance(self.prior_precision, torch.Tensor)
else torch.tensor(self.prior_precision)
)
if prior_prec.ndim == 0 or len(prior_prec) == 1: # scalar
return self.prior_precision * torch.ones(self.n_params, device=self._device)
elif len(prior_prec) == self.n_params: # diagonal
return prior_prec
elif len(prior_prec) == self.n_layers: # per layer
n_params_per_layer = [p.numel() for p in self.params]
return torch.cat(
[
prior * torch.ones(n_params, device=self._device)
for prior, n_params in zip(prior_prec, n_params_per_layer)
]
)
else:
raise ValueError(
"Mismatch of prior and model. Diagonal, scalar, or per-layer prior."
)
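    # Hedged illustration of the three prior precision layouts handled above
    # (`la` is a placeholder for an instantiated Laplace object):
    #
    #     la.prior_precision = 0.5                            # scalar, shared
    #     la.prior_precision = 0.5 * torch.ones(la.n_layers)  # one value per layer
    #     la.prior_precision = 0.5 * torch.ones(la.n_params)  # full diagonal
    #
    # In all three cases `prior_precision_diag` expands to a length-`n_params` vector.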
@property
def prior_mean(self) -> torch.Tensor:
return self._prior_mean
@prior_mean.setter
def prior_mean(self, prior_mean: float | torch.Tensor) -> None:
if np.isscalar(prior_mean) and np.isreal(prior_mean):
self._prior_mean = torch.tensor(prior_mean, device=self._device)
elif isinstance(prior_mean, torch.Tensor):
if prior_mean.ndim == 0:
self._prior_mean = prior_mean.reshape(-1).to(self._device)
elif prior_mean.ndim == 1:
if len(prior_mean) not in [1, self.n_params]:
raise ValueError("Invalid length of prior mean.")
self._prior_mean = prior_mean
else:
raise ValueError("Prior mean has too many dimensions!")
else:
raise ValueError("Invalid argument type of prior mean.")
@property
def prior_precision(self) -> torch.Tensor:
return self._prior_precision
@prior_precision.setter
def prior_precision(self, prior_precision: float | torch.Tensor):
self._posterior_scale = None
if np.isscalar(prior_precision) and np.isreal(prior_precision):
self._prior_precision = torch.tensor([prior_precision], device=self._device)
elif isinstance(prior_precision, torch.Tensor):
if prior_precision.ndim == 0:
# make dimensional
self._prior_precision = prior_precision.reshape(-1).to(self._device)
elif prior_precision.ndim == 1:
if len(prior_precision) not in [1, self.n_layers, self.n_params]:
raise ValueError(
"Length of prior precision does not align with architecture."
)
self._prior_precision = prior_precision.to(self._device)
else:
raise ValueError(
"Prior precision needs to be at most one-dimensional tensor."
)
else:
raise ValueError(
"Prior precision either scalar or torch.Tensor up to 1-dim."
)
def optimize_prior_precision(
self,
pred_type: PredType | str,
method: TuningMethod | str = TuningMethod.MARGLIK,
n_steps: int = 100,
lr: float = 1e-1,
init_prior_prec: float | torch.Tensor = 1.0,
prior_structure: PriorStructure | str = PriorStructure.DIAG,
val_loader: DataLoader | None = None,
loss: torchmetrics.Metric
| Callable[[torch.Tensor], torch.Tensor | float]
| None = None,
log_prior_prec_min: float = -4,
log_prior_prec_max: float = 4,
grid_size: int = 100,
link_approx: LinkApprox | str = LinkApprox.PROBIT,
n_samples: int = 100,
verbose: bool = False,
progress_bar: bool = False,
) -> None:
"""Optimize the prior precision post-hoc using the `method`
specified by the user.
Parameters
----------
pred_type : PredType or str in {'glm', 'nn'}
type of posterior predictive, linearized GLM predictive or neural
            network sampling predictive. The GLM predictive is consistent with the
curvature approximations used here.
        method : TuningMethod or str in {'marglik', 'gridsearch'}, default=TuningMethod.MARGLIK
specifies how the prior precision should be optimized.
n_steps : int, default=100
the number of gradient descent steps to take.
lr : float, default=1e-1
the learning rate to use for gradient descent.
init_prior_prec : float or tensor, default=1.0
initial prior precision before the first optimization step.
        prior_structure : PriorStructure or str in {'scalar', 'layerwise', 'diag'}, default=PriorStructure.DIAG
if init_prior_prec is scalar, the prior precision is optimized with this structure.
otherwise, the structure of init_prior_prec is maintained.
val_loader : torch.data.utils.DataLoader, default=None
            DataLoader for the validation set; each iterate is a validation batch (X, y).
loss : callable or torchmetrics.Metric, default=None
loss function to use for CV. If callable, the loss is computed offline (memory intensive).
If torchmetrics.Metric, running loss is computed (efficient). The default
depends on the likelihood: `RunningNLLMetric()` for classification and
reward modeling, running `MeanSquaredError()` for regression.
log_prior_prec_min : float, default=-4
lower bound of gridsearch interval.
log_prior_prec_max : float, default=4
upper bound of gridsearch interval.
grid_size : int, default=100
number of values to consider inside the gridsearch interval.
link_approx : LinkApprox or str in {'mc', 'probit', 'bridge'}, default=LinkApprox.PROBIT
how to approximate the classification link function for the `'glm'`.
For `pred_type='nn'`, only `'mc'` is possible.
n_samples : int, default=100
number of samples for `link_approx='mc'`.
verbose : bool, default=False
if true, the optimized prior precision will be printed
(can be a large tensor if the prior has a diagonal covariance).
progress_bar : bool, default=False
whether to show a progress bar; updated at every batch-Hessian computation.
            Useful for very large models and large amounts of data, esp. when `subset_of_weights='all'`.
"""
likelihood = (
Likelihood.CLASSIFICATION
if self.likelihood == Likelihood.REWARD_MODELING
else self.likelihood
)
if likelihood == Likelihood.CLASSIFICATION:
warnings.warn(
"By default `link_approx` is `probit`. Make sure to set it equals to "
"the way you want to call `la(test_data, pred_type=..., link_approx=...)`."
)
if method == TuningMethod.MARGLIK:
if val_loader is not None:
warnings.warn(
"`val_loader` will be ignored when `method` == 'marglik'. "
"Do you mean to set `method = 'gridsearch'`?"
)
self.prior_precision = (
init_prior_prec
if isinstance(init_prior_prec, torch.Tensor)
else torch.tensor(init_prior_prec)
)
if (
len(self.prior_precision) == 1
and prior_structure != PriorStructure.SCALAR
):
self.prior_precision = fix_prior_prec_structure(
self.prior_precision.item(),
prior_structure,
self.n_layers,
self.n_params,
self._device,
)
log_prior_prec = self.prior_precision.log()
log_prior_prec.requires_grad = True
optimizer = torch.optim.Adam([log_prior_prec], lr=lr)
if progress_bar:
pbar = tqdm.trange(n_steps)
pbar.set_description("[Optimizing marginal likelihood]")
else:
pbar = range(n_steps)
for _ in pbar:
optimizer.zero_grad()
prior_prec = log_prior_prec.exp()
neg_log_marglik = -self.log_marginal_likelihood(
prior_precision=prior_prec
)
neg_log_marglik.backward()
optimizer.step()
self.prior_precision = log_prior_prec.detach().exp()
elif method == TuningMethod.GRIDSEARCH:
if val_loader is None:
raise ValueError("gridsearch requires a validation set DataLoader")
interval = torch.logspace(log_prior_prec_min, log_prior_prec_max, grid_size)
if loss is None:
loss = (
torchmetrics.MeanSquaredError(num_outputs=self.n_outputs).to(
self._device
)
if likelihood == Likelihood.REGRESSION
else RunningNLLMetric().to(self._device)
)
self.prior_precision = self._gridsearch(
loss,
interval,
val_loader,
pred_type=pred_type,
link_approx=link_approx,
n_samples=n_samples,
progress_bar=progress_bar,
)
else:
raise ValueError("For now only marglik and gridsearch is implemented.")
if verbose:
print(f"Optimized prior precision is {self.prior_precision}.")
def _gridsearch(
self,
loss: torchmetrics.Metric | Callable[[torch.Tensor], torch.Tensor | float],
interval: torch.Tensor,
val_loader: DataLoader,
pred_type: PredType | str,
link_approx: LinkApprox | str = LinkApprox.PROBIT,
n_samples: int = 100,
progress_bar: bool = False,
) -> torch.Tensor:
assert callable(loss) or isinstance(loss, torchmetrics.Metric)
results: list[float] = list()
prior_precs: list[torch.Tensor] = list()
pbar = tqdm.tqdm(interval, disable=not progress_bar)
for prior_prec in pbar:
self.prior_precision = prior_prec
try:
result = validate(
self,
val_loader,
loss,
pred_type=pred_type,
link_approx=link_approx,
n_samples=n_samples,
dict_key_y=self.dict_key_y,
)
except LinAlgError:
result = np.inf
except RuntimeError as err:
if "not positive definite" in str(err):
result = np.inf
else:
raise err
if progress_bar:
pbar.set_description(
f"[Grid search | prior_prec: {prior_prec:.3e}, loss: {result:.3f}]"
)
results.append(result)
prior_precs.append(prior_prec)
return prior_precs[np.argmin(results)]
@property
def sigma_noise(self) -> torch.Tensor:
return self._sigma_noise
@sigma_noise.setter
def sigma_noise(self, sigma_noise: float | torch.Tensor) -> None:
self._posterior_scale = None
if np.isscalar(sigma_noise) and np.isreal(sigma_noise):
self._sigma_noise = torch.tensor(sigma_noise, device=self._device)
elif isinstance(sigma_noise, torch.Tensor):
if sigma_noise.ndim == 0:
self._sigma_noise = sigma_noise.to(self._device)
elif sigma_noise.ndim == 1:
if len(sigma_noise) > 1:
raise ValueError("Only homoscedastic output noise supported.")
self._sigma_noise = sigma_noise[0].to(self._device)
else:
raise ValueError("Sigma noise needs to be scalar or 1-dimensional.")
else:
raise ValueError(
"Invalid type: sigma noise needs to be torch.Tensor or scalar."
)
@property
def _H_factor(self) -> torch.Tensor:
sigma2 = self.sigma_noise.square()
return 1 / sigma2 / self.temperature
def _glm_forward_call(
self,
x: torch.Tensor | MutableMapping,
likelihood: Likelihood | str,
joint: bool = False,
link_approx: LinkApprox | str = LinkApprox.PROBIT,
n_samples: int = 100,
diagonal_output: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Compute the posterior predictive on input data `x` for "glm" pred type.
Parameters
----------
x : torch.Tensor or MutableMapping
`(batch_size, input_shape)` if tensor. If MutableMapping, must contain
the said tensor.
likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
determines the log likelihood Hessian approximation.
link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
how to approximate the classification link function for the `'glm'`.
For `pred_type='nn'`, only 'mc' is possible.
joint : bool
Whether to output a joint predictive distribution in regression with
`pred_type='glm'`. If set to `True`, the predictive distribution
            has the same form as a GP posterior, i.e. N([f(x1), ..., f(xm)], Cov[f(x1), ..., f(xm)]).
If `False`, then only outputs the marginal predictive distribution.
Only available for regression and GLM predictive.
n_samples : int
number of samples for `link_approx='mc'`.
diagonal_output : bool
whether to use a diagonalized posterior predictive on the outputs.
Only works for `pred_type='glm'` and `link_approx='mc'`.
Returns
-------
predictive: torch.Tensor or tuple[torch.Tensor]
For `likelihood='classification'`, a torch.Tensor is returned with
a distribution over classes (similar to a Softmax).
For `likelihood='regression'`, a tuple of torch.Tensor is returned
with the mean and the predictive variance.
For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
is returned with the mean and the predictive covariance.
"""
f_mu, f_var = self._glm_predictive_distribution(
x, joint=joint and likelihood == Likelihood.REGRESSION
)
if likelihood == Likelihood.REGRESSION:
if diagonal_output and not joint:
f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)
return f_mu, f_var
if link_approx == LinkApprox.MC:
return self._glm_predictive_samples(
f_mu,
f_var,
n_samples=n_samples,
diagonal_output=diagonal_output,
).mean(dim=0)
elif link_approx == LinkApprox.PROBIT:
kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
return torch.softmax(kappa * f_mu, dim=-1)
elif "bridge" in link_approx:
# zero mean correction
f_mu -= (
f_var.sum(-1)
* f_mu.sum(-1).reshape(-1, 1)
/ f_var.sum(dim=(1, 2)).reshape(-1, 1)
)
f_var -= torch.einsum(
"bi,bj->bij", f_var.sum(-1), f_var.sum(-2)
) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1)
# Laplace Bridge
_, K = f_mu.size(0), f_mu.size(-1)
f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)
# optional: variance correction
if link_approx == LinkApprox.BRIDGE_NORM:
f_var_diag_mean = f_var_diag.mean(dim=1)
f_var_diag_mean /= torch.as_tensor([K / 2], device=self._device).sqrt()
f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
f_var_diag /= f_var_diag_mean.unsqueeze(-1)
sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
else:
raise ValueError(
"Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
)
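    # The probit branch above corresponds to the standard approximation
    # (a sketch in the notation of the code):
    #   p(y | x, D) ≈ softmax( f_mu / sqrt(1 + (pi / 8) * diag(f_var)) ),
    # i.e. the mean logits are scaled by kappa before the softmax, accounting
    # for predictive variance without Monte Carlo sampling.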
def _glm_predictive_samples(
self,
f_mu: torch.Tensor,
f_var: torch.Tensor,
n_samples: int,
diagonal_output: bool = False,
generator: torch.Generator | None = None,
) -> torch.Tensor:
"""Sample from the posterior predictive on input data `x` using "glm" prediction
type.
Parameters
----------
f_mu : torch.Tensor or MutableMapping
glm predictive mean `(batch_size, output_shape)`
f_var : torch.Tensor or MutableMapping
glm predictive covariances `(batch_size, output_shape, output_shape)`
n_samples : int
number of samples
diagonal_output : bool
whether to use a diagonalized glm posterior predictive on the outputs.
generator : torch.Generator, optional
random number generator to control the samples (if sampling used)
Returns
-------
samples : torch.Tensor
samples `(n_samples, batch_size, output_shape)`
"""
assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])
if diagonal_output:
f_var = torch.diagonal(f_var, dim1=1, dim2=2)
f_samples = normal_samples(f_mu, f_var, n_samples, generator)
if self.likelihood == Likelihood.REGRESSION:
return f_samples
else:
return torch.softmax(f_samples, dim=-1)
class ParametricLaplace(BaseLaplace):
"""
Parametric Laplace class.
Subclasses need to specify how the Hessian approximation is initialized,
how to add up curvature over training data, how to sample from the
Laplace approximation, and how to compute the functional variance.
A Laplace approximation is represented by a MAP which is given by the
`model` parameter and a posterior precision or covariance specifying
a Gaussian distribution \\(\\mathcal{N}(\\theta_{MAP}, P^{-1})\\).
The goal of this class is to compute the posterior precision \\(P\\)
which sums as
\\[
P = \\sum_{n=1}^N \\nabla^2_\\theta \\log p(\\mathcal{D}_n \\mid \\theta)
\\vert_{\\theta_{MAP}} + \\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}}.
\\]
Every subclass implements different approximations to the log likelihood Hessians,
for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have
a simple form for \\(\\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}} = P_0 \\).
In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in
all cases \\(P_0 = \\textrm{diag}(p_0)\\) and the structure of \\(p_0\\) can be varied.
"""
def __init__(
self,
model: nn.Module,
likelihood: Likelihood | str,
sigma_noise: float | torch.Tensor = 1.0,
prior_precision: float | torch.Tensor = 1.0,
prior_mean: float | torch.Tensor = 0.0,
temperature: float = 1.0,
enable_backprop: bool = False,
dict_key_x: str = "input_ids",
dict_key_y: str = "labels",
backend: type[CurvatureInterface] | None = None,
backend_kwargs: dict[str, Any] | None = None,
asdl_fisher_kwargs: dict[str, Any] | None = None,
):
super().__init__(
model,
likelihood,
sigma_noise,
prior_precision,
prior_mean,
temperature,
enable_backprop,
dict_key_x,
dict_key_y,
backend,
backend_kwargs,
asdl_fisher_kwargs,
)
if not hasattr(self, "H"):
self._init_H()
# posterior mean/mode
self.mean: float | torch.Tensor = self.prior_mean
def _init_H(self) -> None:
raise NotImplementedError
def _check_H_init(self) -> None:
if getattr(self, "H", None) is None:
raise AttributeError("Laplace not fitted. Run fit() first.")
def fit(
self,
train_loader: DataLoader,
override: bool = True,
progress_bar: bool = False,
) -> None:
"""Fit the local Laplace approximation at the parameters of the model.
Parameters
----------
train_loader : torch.data.utils.DataLoader
each iterate is a training batch, either `(X, y)` tensors or a dict-like
object containing keys as expressed by `self.dict_key_x` and
`self.dict_key_y`. `train_loader.dataset` needs to be set to access
\\(N\\), size of the data set.
override : bool, default=True
whether to initialize H, loss, and n_data again; setting to False is useful for
online learning settings to accumulate a sequential posterior approximation.
progress_bar : bool, default=False
whether to show a progress bar; updated at every batch-Hessian computation.
            Useful for very large models and large amounts of data, esp. when `subset_of_weights='all'`.
"""
if override:
self._init_H()
self.loss: float | torch.Tensor = 0
self.n_data: int = 0
self.model.eval()
self.mean: torch.Tensor = parameters_to_vector(self.params)
if not self.enable_backprop:
self.mean = self.mean.detach()
data: (
tuple[torch.Tensor, torch.Tensor] | MutableMapping[str, torch.Tensor | Any]
) = next(iter(train_loader))
with torch.no_grad():
if isinstance(data, MutableMapping): # To support Huggingface dataset
if "backpack" in self._backend_cls.__name__.lower() or (
isinstance(self, DiagLaplace) and self._backend_cls == CurvlinopsEF
):
raise ValueError(
"Currently DiagEF is not supported under CurvlinopsEF backend "
+ "for custom models with non-tensor inputs "
+ "(https://github.com/pytorch/functorch/issues/159). Consider "
+ "using AsdlEF backend instead. The same limitation applies "
+ "to all BackPACK backend"
)
out = self.model(data)
else:
X = data[0]
try:
out = self.model(X[:1].to(self._device))
except (TypeError, AttributeError):
out = self.model(X.to(self._device))
self.n_outputs = out.shape[-1]
setattr(self.model, "output_size", self.n_outputs)
N = len(train_loader.dataset)
pbar = tqdm.tqdm(train_loader, disable=not progress_bar)
pbar.set_description("[Computing Hessian]")
for data in pbar:
if isinstance(data, MutableMapping): # To support Huggingface dataset
X, y = data, data[self.dict_key_y].to(self._device)
else:
X, y = data
X, y = X.to(self._device), y.to(self._device)
if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
raise ValueError(
f"The model's output has {out.ndim} dims but "
f"the target has {y.ndim} dims."
)
self.model.zero_grad()
loss_batch, H_batch = self._curv_closure(X, y, N=N)
self.loss += loss_batch
self.H += H_batch
self.n_data += N
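    # Hedged sketch of the two fitting modes described in the docstring above
    # (`la`, `loader_task1`, `loader_task2` are placeholders):
    #
    #     la.fit(loader_task1)                  # fresh posterior approximation
    #     la.fit(loader_task2, override=False)  # accumulate curvature, e.g. for
    #                                           # sequential/online posteriors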
@property
def scatter(self) -> torch.Tensor:
"""Computes the _scatter_, a term of the log marginal likelihood that
corresponds to L-2 regularization:
`scatter` = \\((\\theta_{MAP} - \\mu_0)^{T} P_0 (\\theta_{MAP} - \\mu_0) \\).
Returns
-------
scatter: torch.Tensor
"""
delta = self.mean - self.prior_mean
return (delta * self.prior_precision_diag) @ delta
@property
def log_det_prior_precision(self) -> torch.Tensor:
"""Compute log determinant of the prior precision
\\(\\log \\det P_0\\)
Returns
-------
log_det : torch.Tensor
"""
return self.prior_precision_diag.log().sum()
@property
def log_det_posterior_precision(self) -> torch.Tensor:
"""Compute log determinant of the posterior precision
\\(\\log \\det P\\) which depends on the subclasses structure
used for the Hessian approximation.
Returns
-------
log_det : torch.Tensor
"""
raise NotImplementedError
@property
def log_det_ratio(self) -> torch.Tensor:
"""Compute the log determinant ratio, a part of the log marginal likelihood.
\\[
\\log \\frac{\\det P}{\\det P_0} = \\log \\det P - \\log \\det P_0
\\]
Returns
-------
log_det_ratio : torch.Tensor
"""
return self.log_det_posterior_precision - self.log_det_prior_precision
def square_norm(self, value) -> torch.Tensor:
"""Compute the square norm under post. Precision with `value-self.mean` as 𝛥:
\\[
\\Delta^\top P \\Delta
\\]
Returns
-------
square_form
"""
raise NotImplementedError
def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor:
"""Compute the log probability under the (current) Laplace approximation.
Parameters
----------
value: torch.Tensor
normalized : bool, default=True
whether to return log of a properly normalized Gaussian or just the
terms that depend on `value`.
Returns
-------
log_prob : torch.Tensor
"""
if not normalized:
return -self.square_norm(value) / 2
log_prob = (
-self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
)
log_prob -= self.square_norm(value) / 2
return log_prob
def log_marginal_likelihood(
self,
prior_precision: torch.Tensor | None = None,
sigma_noise: torch.Tensor | None = None,
) -> torch.Tensor:
"""Compute the Laplace approximation to the log marginal likelihood subject
to specific Hessian approximations that subclasses implement.
Requires that the Laplace approximation has been fit before.
The resulting torch.Tensor is differentiable in `prior_precision` and
`sigma_noise` if these have gradients enabled.
By passing `prior_precision` or `sigma_noise`, the current value is
overwritten. This is useful for iterating on the log marginal likelihood.
Parameters
----------
prior_precision : torch.Tensor, optional
prior precision if should be changed from current `prior_precision` value
sigma_noise : torch.Tensor, optional
observation noise standard deviation if should be changed
Returns
-------
log_marglik : torch.Tensor
"""
# update prior precision (useful when iterating on marglik)
if prior_precision is not None:
self.prior_precision = prior_precision
# update sigma_noise (useful when iterating on marglik)
if sigma_noise is not None:
if self.likelihood != Likelihood.REGRESSION:
raise ValueError("Can only change sigma_noise for regression.")
self.sigma_noise = sigma_noise
return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)
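    # Hedged sketch of iterating on the marginal likelihood as described above
    # (regression case, since `sigma_noise` is only tunable there; `la` is a
    # fitted Laplace object):
    #
    #     log_prior = torch.zeros(1, requires_grad=True)
    #     log_sigma = torch.zeros(1, requires_grad=True)
    #     opt = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
    #     for _ in range(100):
    #         opt.zero_grad()
    #         neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    #         neg_marglik.backward()
    #         opt.step()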
def __call__(
self,
x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
pred_type: PredType | str = PredType.GLM,
joint: bool = False,
link_approx: LinkApprox | str = LinkApprox.PROBIT,
n_samples: int = 100,
diagonal_output: bool = False,
generator: torch.Generator | None = None,
fitting: bool = False,
**model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Compute the posterior predictive on input data `x`.
Parameters
----------
x : torch.Tensor or MutableMapping
`(batch_size, input_shape)` if tensor. If MutableMapping, must contain
the said tensor.