
Fix PyTorch 1.10 warnings in test suite (#1835)
* Fix warnings in CG

* Fix linalg deprecation warnings

* Fix __floordiv__ deprecation warnings

* Fix cholesky jitter value deprecation warnings

* Fix ScaleToBounds deprecation error

* Fix meshgrid deprecation warnings
gpleiss authored Nov 22, 2021
1 parent 04c6c85 commit 452443a
Showing 15 changed files with 34 additions and 31 deletions.
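The meshgrid change itself is not among the hunks shown below, but the standard PyTorch 1.10 migration it refers to is passing an explicit `indexing` argument; a minimal sketch of the pattern (illustrative, not taken from this diff):

```python
import torch

x = torch.linspace(0, 1, 3)
y = torch.linspace(0, 1, 4)

# PyTorch 1.10 warns when `indexing` is omitted; "ij" matches the old default behavior.
grid_x, grid_y = torch.meshgrid(x, y, indexing="ij")
print(grid_x.shape, grid_y.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
```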
2 changes: 1 addition & 1 deletion .conda/meta.yaml
@@ -16,7 +16,7 @@ requirements:
     - python>=3.6
 
   run:
-    - pytorch>=1.8.1
+    - pytorch>=1.9
     - scikit-learn
     - scipy
2 changes: 1 addition & 1 deletion .github/workflows/unittests.yml
@@ -28,7 +28,7 @@ jobs:
           if [[ ${{ matrix.pytorch-version }} = "master" ]]; then
             pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html;
           else
-            pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html;
+            pip install torch==1.9+cpu -f https://download.pytorch.org/whl/torch_stable.html;
           fi
           if [[ ${{ matrix.pyro }} == "with-pyro" ]]; then
             pip install git+https://github.com/pyro-ppl/pyro@master;
2 changes: 1 addition & 1 deletion README.md
@@ -19,7 +19,7 @@ See our numerous [**examples and tutorials**](https://gpytorch.readthedocs.io/en
 
 **Requirements**:
 - Python >= 3.6
-- PyTorch >= 1.8.1
+- PyTorch >= 1.9
 
 Install GPyTorch using pip or conda:
4 changes: 2 additions & 2 deletions gpytorch/lazy/block_diag_lazy_tensor.py
@@ -52,8 +52,8 @@ def _cholesky_solve(self, rhs, upper: bool = False):
 
     def _get_indices(self, row_index, col_index, *batch_indices):
         # Figure out what block the row/column indices belong to
-        row_index_block = row_index // self.base_lazy_tensor.size(-2)
-        col_index_block = col_index // self.base_lazy_tensor.size(-1)
+        row_index_block = torch.div(row_index, self.base_lazy_tensor.size(-2), rounding_mode="floor")
+        col_index_block = torch.div(col_index, self.base_lazy_tensor.size(-1), rounding_mode="floor")
 
         # Find the row/col index within each block
         row_index = row_index.fmod(self.base_lazy_tensor.size(-2))
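For context on the change above: tensor `//` (`__floordiv__`) was deprecated in favor of `torch.div` with an explicit `rounding_mode`; a minimal sketch of the equivalence, with illustrative values:

```python
import torch

row_index = torch.tensor([0, 3, 7, 11])
block_size = 4  # stands in for self.base_lazy_tensor.size(-2)

# Deprecated: row_index // block_size
row_index_block = torch.div(row_index, block_size, rounding_mode="floor")
print(row_index_block)  # tensor([0, 0, 1, 2])
```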
4 changes: 2 additions & 2 deletions gpytorch/lazy/block_interleaved_lazy_tensor.py
@@ -53,8 +53,8 @@ def _get_indices(self, row_index, col_index, *batch_indices):
         col_index_block = col_index.fmod(self.base_lazy_tensor.size(-3))
 
         # Find the row/col index within each block
-        row_index = row_index // self.base_lazy_tensor.size(-3)
-        col_index = col_index // self.base_lazy_tensor.size(-3)
+        row_index = torch.div(row_index, self.base_lazy_tensor.size(-3), rounding_mode="floor")
+        col_index = torch.div(col_index, self.base_lazy_tensor.size(-3), rounding_mode="floor")
 
         # If the row/column blocks do not agree, then we have off diagonal elements
         # These elements should be zeroed out
4 changes: 2 additions & 2 deletions gpytorch/lazy/kronecker_product_lazy_tensor.py
@@ -175,8 +175,8 @@ def _get_indices(self, row_index, col_index, *batch_indices):
             row_factor //= sub_row_size
             col_factor //= sub_col_size
             sub_res = lazy_tensor._get_indices(
-                (row_index // row_factor).fmod(sub_row_size),
-                (col_index // col_factor).fmod(sub_col_size),
+                torch.div(row_index, row_factor, rounding_mode="floor").fmod(sub_row_size),
+                torch.div(col_index, col_factor, rounding_mode="floor").fmod(sub_col_size),
                 *batch_indices,
             )
             res = sub_res if res is None else (sub_res * res)
2 changes: 1 addition & 1 deletion gpytorch/models/exact_prediction_strategies.py
@@ -676,7 +676,7 @@ def covar_cache(self):
             torch.eye(train_factor.size(-1), dtype=train_factor.dtype, device=train_factor.device)
             - (train_factor.transpose(-1, -2) @ train_train_covar.inv_matmul(train_factor)) * constant
         )
-        return psd_safe_cholesky(inner_term, jitter=settings.cholesky_jitter.value())
+        return psd_safe_cholesky(inner_term)
 
     def exact_prediction(self, joint_mean, joint_covar):
         # Find the components of the distribution that contain test data
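The jitter change relies on `psd_safe_cholesky` choosing a dtype-appropriate default from `settings.cholesky_jitter` when no explicit `jitter` is passed (calling `settings.cholesky_jitter.value()` without a dtype is what triggered the warning). A minimal sketch of the new call pattern, assuming a standard GPyTorch install:

```python
import torch
from gpytorch.utils.cholesky import psd_safe_cholesky

A = torch.randn(5, 5, dtype=torch.float64)
A = A @ A.T  # symmetric positive semi-definite

# No explicit jitter: psd_safe_cholesky falls back to the active
# dtype-dependent settings.cholesky_jitter value if factorization fails.
L = psd_safe_cholesky(A)
assert torch.allclose(L @ L.T, A, atol=1e-6)
```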
6 changes: 3 additions & 3 deletions gpytorch/test/lazy_tensor_test_case.py
@@ -297,7 +297,7 @@ class LazyTensorTestCase(RectangularLazyTensorTestCase):
         "root_decomposition": {"rtol": 0.05},
         "root_inv_decomposition": {"rtol": 0.05, "atol": 0.02},
         "sample": {"rtol": 0.3, "atol": 0.3},
-        "sqrt_inv_matmul": {"rtol": 1e-4, "atol": 1e-3},
+        "sqrt_inv_matmul": {"rtol": 1e-2, "atol": 1e-3},
         "symeig": {"double": {"rtol": 1e-4, "atol": 1e-3}, "float": {"rtol": 1e-3, "atol": 1e-2}},
         "svd": {"rtol": 1e-4, "atol": 1e-3},
     }
@@ -706,7 +706,7 @@ def test_sqrt_inv_matmul(self):
         # Perform forward pass
         with gpytorch.settings.max_cg_iterations(200):
             sqrt_inv_matmul_res, inv_quad_res = lazy_tensor.sqrt_inv_matmul(rhs, lhs)
-        evals, evecs = evaluated.symeig(eigenvectors=True)
+        evals, evecs = torch.linalg.eigh(evaluated)
         matrix_inv_root = evecs @ (evals.sqrt().reciprocal().unsqueeze(-1) * evecs.transpose(-1, -2))
         sqrt_inv_matmul_actual = lhs_copy @ matrix_inv_root @ rhs_copy
         inv_quad_actual = (lhs_copy @ matrix_inv_root).pow(2).sum(dim=-1)
@@ -744,7 +744,7 @@ def test_sqrt_inv_matmul_no_lhs(self):
         # Perform forward pass
         with gpytorch.settings.max_cg_iterations(200):
             sqrt_inv_matmul_res = lazy_tensor.sqrt_inv_matmul(rhs)
-        evals, evecs = evaluated.symeig(eigenvectors=True)
+        evals, evecs = torch.linalg.eigh(evaluated)
         matrix_inv_root = evecs @ (evals.sqrt().reciprocal().unsqueeze(-1) * evecs.transpose(-1, -2))
         sqrt_inv_matmul_actual = matrix_inv_root @ rhs_copy
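For reference, the `symeig` to `torch.linalg.eigh` migration used in these tests: both return eigenvalues in ascending order, so the inverse-root reconstruction is unchanged. A minimal sketch:

```python
import torch

evaluated = torch.randn(4, 4, dtype=torch.float64)
evaluated = evaluated @ evaluated.T + 4 * torch.eye(4, dtype=torch.float64)  # SPD

# Deprecated: evals, evecs = evaluated.symeig(eigenvectors=True)
evals, evecs = torch.linalg.eigh(evaluated)

# Same inverse-root construction as the test
matrix_inv_root = evecs @ (evals.sqrt().reciprocal().unsqueeze(-1) * evecs.transpose(-1, -2))
identity = matrix_inv_root @ evaluated @ matrix_inv_root
assert torch.allclose(identity, torch.eye(4, dtype=torch.float64), atol=1e-8)
```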
8 changes: 5 additions & 3 deletions gpytorch/utils/contour_integral_quad.py
@@ -71,7 +71,7 @@ def sqrt_precond_matmul(rhs):
     )
     lanczos_mat = lanczos_mat.squeeze(0)  # We have an extra singleton batch dimension from the Lanczos init
 
-    """
+    r"""
     K^{-1/2} b = 2/pi \int_0^\infty (K - t^2 I)^{-1} dt
     We'll approximate this integral as a sum using quadrature
     We'll determine the appropriate values of t, as well as their weights using elliptical integrals
@@ -81,9 +81,11 @@ def sqrt_precond_matmul(rhs):
     # We'll do this with Lanczos
     try:
         if settings.verbose_linalg.on():
-            settings.verbose_linalg.logger.debug(f"Running symeig on a matrix of size {lanczos_mat.shape}.")
+            settings.verbose_linalg.logger.debug(
+                f"Running torch.linalg.eigvalsh on a matrix of size {lanczos_mat.shape}."
+            )
 
-        approx_eigs = lanczos_mat.symeig()[0]
+        approx_eigs = torch.linalg.eigvalsh(lanczos_mat)
         if approx_eigs.min() <= 0:
             raise RuntimeError
     except RuntimeError:
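Where only eigenvalues are needed, as for the Lanczos estimate above, `torch.linalg.eigvalsh` replaces `symeig()[0]` and skips the eigenvector computation entirely; a minimal sketch:

```python
import torch

lanczos_mat = torch.randn(5, 5, dtype=torch.float64)
lanczos_mat = lanczos_mat + lanczos_mat.T  # symmetric, like the Lanczos tridiagonal matrix

# Deprecated: approx_eigs = lanczos_mat.symeig()[0]
approx_eigs = torch.linalg.eigvalsh(lanczos_mat)  # eigenvalues only, ascending
print(approx_eigs.min() <= 0)  # the code above raises and falls back in this case
```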
4 changes: 2 additions & 2 deletions gpytorch/utils/linear_cg.py
@@ -292,8 +292,8 @@ def linear_cg(
 
         # Update tridiagonal matrices, if applicable
         if n_tridiag and k < n_tridiag_iter and update_tridiag:
-            alpha_tridiag = alpha.squeeze_(-2).narrow(-1, 0, n_tridiag)
-            beta_tridiag = beta.squeeze_(-2).narrow(-1, 0, n_tridiag)
+            alpha_tridiag = alpha.squeeze(-2).narrow(-1, 0, n_tridiag)
+            beta_tridiag = beta.squeeze(-2).narrow(-1, 0, n_tridiag)
             torch.eq(alpha_tridiag, 0, out=alpha_tridiag_is_zero)
             alpha_tridiag.masked_fill_(alpha_tridiag_is_zero, 1)
             torch.reciprocal(alpha_tridiag, out=alpha_reciprocal)
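Unlike `squeeze_`, the out-of-place `squeeze` returns a view and leaves the original buffer's shape untouched for subsequent CG iterations; a minimal illustration of the difference:

```python
import torch

alpha = torch.ones(3, 1, 4)
view = alpha.squeeze(-2)        # out-of-place: alpha keeps its shape
print(alpha.shape, view.shape)  # torch.Size([3, 1, 4]) torch.Size([3, 4])

beta = torch.ones(3, 1, 4)
beta.squeeze_(-2)               # in-place: beta itself is reshaped
print(beta.shape)               # torch.Size([3, 4])
```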
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,3 +1,3 @@
-torch>=1.8.1
+torch>=1.9
 scikit-learn
 scipy
2 changes: 1 addition & 1 deletion setup.py
@@ -25,7 +25,7 @@ def find_version(*file_paths):
 version = find_version("gpytorch", "__init__.py")
 
 
-torch_min = "1.8.1"
+torch_min = "1.9"
 install_requires = [">=".join(["torch", torch_min]), "scikit-learn", "scipy"]
 # if recent dev version of PyTorch is installed, no need to install stable
 try:
7 changes: 4 additions & 3 deletions test/utils/test_grid.py
@@ -12,9 +12,10 @@ def test_scale_to_bounds(self):
         """
         """
         x = torch.randn(100) * 50
-        res = gpytorch.utils.grid.scale_to_bounds(x, -1, 1)
-        self.assertGreater(res.min().item(), -1)
-        self.assertLess(res.max().item(), 1)
+        scale_module = gpytorch.utils.grid.ScaleToBounds(-1.0, 1.0)
+        res = scale_module(x)
+        self.assertGreater(res.min().item(), -1.0)
+        self.assertLess(res.max().item(), 1.0)
 
     def test_choose_grid_size(self):
         """
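The functional `scale_to_bounds` helper is deprecated in favor of the `ScaleToBounds` module, which the updated test exercises; standalone usage looks like this (a sketch assuming the import path shown in the hunk above):

```python
import torch
import gpytorch

x = torch.randn(100) * 50

# Old (deprecated): res = gpytorch.utils.grid.scale_to_bounds(x, -1.0, 1.0)
scale_module = gpytorch.utils.grid.ScaleToBounds(-1.0, 1.0)
res = scale_module(x)
print(res.min().item(), res.max().item())  # strictly inside (-1, 1)
```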
12 changes: 6 additions & 6 deletions test/utils/test_linear_cg.py
@@ -33,7 +33,7 @@ def test_cg(self):
         solves = linear_cg(matrix.matmul, rhs=rhs, max_iter=size)
 
         # Check cg
-        matrix_chol = matrix.cholesky()
+        matrix_chol = torch.linalg.cholesky(matrix)
         actual = torch.cholesky_solve(rhs, matrix_chol)
         self.assertTrue(torch.allclose(solves, actual, atol=1e-3, rtol=1e-4))
 
@@ -50,14 +50,14 @@ def test_cg_with_tridiag(self):
         )
 
         # Check cg
-        matrix_chol = matrix.cholesky()
+        matrix_chol = torch.linalg.cholesky(matrix)
         actual = torch.cholesky_solve(rhs, matrix_chol)
         self.assertTrue(torch.allclose(solves, actual, atol=1e-3, rtol=1e-4))
 
         # Check tridiag
-        eigs = matrix.symeig()[0]
+        eigs = torch.linalg.eigvalsh(matrix)
         for i in range(5):
-            approx_eigs = t_mats[i].symeig()[0]
+            approx_eigs = torch.linalg.eigvalsh(t_mats[i])
             self.assertTrue(torch.allclose(eigs, approx_eigs, atol=1e-3, rtol=1e-4))
 
     def test_batch_cg(self):
@@ -96,9 +96,9 @@ def test_batch_cg_with_tridiag(self):
 
         # Check tridiag
         for i in range(5):
-            eigs = matrix[i].symeig()[0]
+            eigs = torch.linalg.eigvalsh(matrix[i])
             for j in range(8):
-                approx_eigs = t_mats[j, i].symeig()[0]
+                approx_eigs = torch.linalg.eigvalsh(t_mats[j, i])
                 self.assertTrue(torch.allclose(eigs, approx_eigs, atol=1e-3, rtol=1e-4))
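For reference, the deprecated `Tensor.cholesky()` maps directly onto `torch.linalg.cholesky`, and the factor still feeds `torch.cholesky_solve` unchanged; a minimal sketch of the check these tests perform:

```python
import torch

size = 6
matrix = torch.randn(size, size, dtype=torch.float64)
matrix = matrix @ matrix.T + size * torch.eye(size, dtype=torch.float64)  # well-conditioned SPD
rhs = torch.randn(size, 2, dtype=torch.float64)

# Deprecated: matrix_chol = matrix.cholesky()
matrix_chol = torch.linalg.cholesky(matrix)
actual = torch.cholesky_solve(rhs, matrix_chol)
assert torch.allclose(matrix @ actual, rhs, atol=1e-8)
```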
4 changes: 2 additions & 2 deletions test/variational/test_natural_variational_distribution.py
@@ -112,7 +112,7 @@ def test_natgrad(self, D=5):
         mu = torch.randn(D)
         cov = torch.randn(D, D)
         cov = cov @ cov.t()
-        dist = MultivariateNormal(mu, CholLazyTensor(TriangularLazyTensor(cov.cholesky())))
+        dist = MultivariateNormal(mu, CholLazyTensor(TriangularLazyTensor(torch.linalg.cholesky(cov))))
         sample = dist.sample()
 
         v_dist = TrilNaturalVariationalDistribution(D, mean_init_std=0.0)
@@ -132,7 +132,7 @@ def f(natural_vec, natural_tril_mat):
             "Transform natural_tril_mat to L"
             Sigma = torch.inverse(-2 * natural_tril_mat)
             mu = natural_vec
-            return mu, Sigma.cholesky().inverse().tril()
+            return mu, torch.linalg.cholesky(Sigma).inverse().tril()
 
         (mu_ref, natural_tril_mat_ref), (dout_dmu_ref, dout_dnat2_ref) = jvp(
             f,
