Skip to content

Commit

Permalink
Merge pull request #2388 from douglas-boubert/grad_kernels_only_create_covariance_if_needed
Browse files Browse the repository at this point in the history

Stop rbf_kernel_grad and rbf_kernel_gradgrad creating the full covariance matrix unnecessarily
  • Loading branch information
Balandat authored Aug 5, 2023
2 parents 090d6e1 + eed36a4 commit 8979210
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions gpytorch/kernels/rbf_kernel_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def forward(self, x1, x2, diag=False, **params):
n1, d = x1.shape[-2:]
n2 = x2.shape[-2]

K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype)

if not diag:
K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype)

# Scale the inputs by the lengthscale (for stability)
x1_ = x1.div(self.lengthscale)
x2_ = x2.div(self.lengthscale)
Expand Down
4 changes: 2 additions & 2 deletions gpytorch/kernels/rbf_kernel_gradgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def forward(self, x1, x2, diag=False, **params):
n1, d = x1.shape[-2:]
n2 = x2.shape[-2]

K = torch.zeros(*batch_shape, n1 * (2 * d + 1), n2 * (2 * d + 1), device=x1.device, dtype=x1.dtype)

if not diag:
K = torch.zeros(*batch_shape, n1 * (2 * d + 1), n2 * (2 * d + 1), device=x1.device, dtype=x1.dtype)

# Scale the inputs by the lengthscale (for stability)
x1_ = x1.div(self.lengthscale)
x2_ = x2.div(self.lengthscale)
Expand Down

0 comments on commit 8979210

Please sign in to comment.