Updated classic_rffs to use register_buffers in place of Parameters for precision and covariance matrices; this avoids issues with initialization by other model classes.
jlparkI committed Oct 27, 2023
1 parent 11a8a79 commit df99b90
Showing 1 changed file with 7 additions and 7 deletions.
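
For background on why this change avoids the initialization issue mentioned in the commit message: a registered buffer is saved in the module's state_dict() and moved by .to()/.cuda() just like a parameter, but it is not returned by parameters(), so optimizers and parameter-wide initialization routines in a wrapping model leave it alone. Below is a minimal illustrative sketch (class and attribute names here are hypothetical, not taken from this repository):

import torch
from torch import nn

class WithParameter(nn.Module):
    """Old-style storage: a frozen Parameter still appears in parameters()."""
    def __init__(self, n: int = 4):
        super().__init__()
        self.precision = nn.Parameter(torch.zeros((n, n)), requires_grad=False)

class WithBuffer(nn.Module):
    """New-style storage: a buffer is checkpointed and moved with the module,
    but is invisible to parameters()-based initialization code."""
    def __init__(self, n: int = 4):
        super().__init__()
        self.register_buffer("precision", torch.zeros((n, n)))

print(len(list(WithParameter().parameters())))   # 1 -- external init code may touch it
print(len(list(WithBuffer().parameters())))      # 0 -- left alone
print("precision" in WithBuffer().state_dict())  # True -- still saved and restored
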
uncertaintyAwareDeepLearn/classic_rffs.py: 7 additions & 7 deletions
@@ -91,12 +91,10 @@ def __init__(self, in_features: int, RFFs: int, out_targets: int=1,
         self.num_freqs = int(0.5 * RFFs)
         self.feature_scale = math.sqrt(2. / float(self.num_freqs))

-        self.register_buffer("weight_mat", torch.empty((in_features, self.num_freqs), **factory_kwargs))
+        self.register_buffer("weight_mat", torch.zeros((in_features, self.num_freqs), **factory_kwargs))
         self.output_weights = Parameter(torch.empty((RFFs, out_targets), **factory_kwargs))
-        self.covariance = Parameter((1 / self.ridge_penalty) * torch.eye(RFFs),
-                                    requires_grad=False)
-        self.precision_initial = torch.zeros((RFFs, RFFs), requires_grad=False)
-        self.precision = Parameter(self.precision_initial, requires_grad=False)
+        self.register_buffer("covariance", torch.zeros((RFFs, RFFs), **factory_kwargs))
+        self.register_buffer("precision", torch.zeros((RFFs, RFFs), **factory_kwargs))
         self.reset_parameters()


@@ -122,23 +120,24 @@ def reset_parameters(self) -> None:
         normal -- in fact, that would set the variance on our sqexp kernel
         to something other than 1 (which is ok, but might be unexpected for
         the user)."""
         self.fitted = False
         with torch.no_grad():
             rgen = torch.Generator()
             rgen.manual_seed(self.random_seed)
             self.weight_mat = torch.randn(generator = rgen,
                     size = self.weight_mat.size())
             self.output_weights[:] = torch.randn(generator = rgen,
                     size = self.output_weights.size())
-            self.precision[...] = self.precision_initial.detach()
+            self.covariance[:] = (1 / self.ridge_penalty) * torch.eye(self.RFFs)
+            self.precision[:] = 0.


     def reset_covariance(self) -> None:
         """Resets the covariance to the initial values. Useful if
         planning to generate the precision & covariance matrices
         on the final epoch."""
         self.fitted = False
         with torch.no_grad():
-            self.precision[...] = self.precision_initial.detach()
+            self.precision[:] = 0.
             self.covariance[:] = (1 / self.ridge_penalty) * torch.eye(self.RFFs)

@@ -172,6 +171,7 @@ def forward(self, input_tensor: Tensor, update_precision: bool = False,
         logits = rff_mat @ self.output_weights

         if update_precision:
+            self.fitted = False
             self._update_precision(rff_mat, logits)

         if get_var:
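
For orientation, a hedged usage sketch of the flow these hunks touch, based only on the method signatures and docstrings visible above; the class name and import used here are assumptions, not confirmed by this diff:

import torch
# Assumed import path and class name -- only the module file
# (uncertaintyAwareDeepLearn/classic_rffs.py) appears in this commit.
from uncertaintyAwareDeepLearn.classic_rffs import VanillaRFFLayer

layer = VanillaRFFLayer(in_features=32, RFFs=512, out_targets=1)
x = torch.randn(8, 32)

# Ordinary forward passes leave the precision matrix untouched.
preds = layer(x)

# On a final pass over the training data (per the reset_covariance docstring),
# clear any stale state, then accumulate the precision matrix; per the change
# in forward(), requesting the update also flags the layer as not yet fitted.
layer.reset_covariance()
preds = layer(x, update_precision=True)

# Once the covariance has been generated from the precision matrix
# (a step not shown in this diff), variance estimates can be requested:
# preds, var = layer(x, get_var=True)
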
