It appears that the new JAX release, v0.4.36, breaks cola.linalg.solve with the Cholesky solver: it results in NaN gradients.
k_zz_inv_U = cola.linalg.solve(K, self.U.reshape(self.M, -1), alg=Cholesky())
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/cola/linalg/inverse/inv.py", line 39, in solve
return inv(A, alg) @ b
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/plum/dispatcher.py", line 93, in new_fn
return fn(*new_args)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/plum/function.py", line 444, in __call__
return _convert(method(*args, **kw_args), return_type)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/cola/linalg/inverse/inv.py", line 98, in inv
L = cholesky(A)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/plum/function.py", line 444, in __call__
return _convert(method(*args, **kw_args), return_type)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/cola/linalg/decompositions/decompositions.py", line 154, in cholesky
return Triangular(A.xnp.cholesky(A.to_dense()), lower=True)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/cola/ops/operator_base.py", line 77, in to_dense
return self @ self.xnp.eye(self.shape[-1], self.shape[-1], dtype=self.dtype, device=self.device)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/cola/backends/jax_fns.py", line 154, in eye
return jnp.eye(N=n, M=m, dtype=dtype)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 6636, in eye
output = _eye(N, M=M, k=k, dtype=dtype)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 6659, in _eye
return (i + offset == j).astype(dtype)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 1049, in op
return getattr(self.aval, f"_{name}")(self, *args)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 573, in deferring_binary_op
return binary_op(*args)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/jax/_src/numpy/ufunc_api.py", line 179, in __call__
return call(*args)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Cannot lower jaxpr with verifier errors:
type of return operand 0 ('tensor<50x50xi64>') doesn't match function result type ('tensor<50x50xi32>') in function @main
at loc(unknown)
see current operation: "func.return"(%1) : (tensor<50x50xi64>) -> ()
at loc(unknown)
Define JAX_DUMP_IR_TO to dump the module.
The code that generated this is a very simple sparse GP using inducing points, which is well tested and works fine on JAX versions 0.4.35 and below.
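The failing call can be approximated in plain JAX as a minimal sketch, without cola — here jax.scipy's Cholesky factor/solve stands in for cola.linalg.solve(K, U, alg=Cholesky()), and K and U are random placeholders rather than the actual kernel matrix and inducing-point values from the GP:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

M = 50  # number of inducing points (matches the 50x50 tensors in the traceback)

# Placeholder PSD kernel matrix K and right-hand side U (assumptions, not the real model).
A = jax.random.normal(jax.random.PRNGKey(0), (M, M))
K = A @ A.T + 1e-3 * jnp.eye(M)
U = jax.random.normal(jax.random.PRNGKey(1), (M, 3))

# Plain-JAX equivalent of cola.linalg.solve(K, U, alg=Cholesky()).
c, lower = cho_factor(K, lower=True)
k_zz_inv_U = cho_solve((c, lower), U)
```

This path goes straight through jax.scipy and avoids cola's I_like/eye machinery, which is where the traceback above ends up.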
@adam-hartshorne, thanks for bringing this discrepancy introduced in JAX v0.4.36 to our attention. I'll do some analysis on the Cholesky function in JAX and report back.
In the meantime, could you share more code to replicate the exact error above?
I figured out that solve isn't causing this behaviour. Replacing cola.ops.I_like(k_zz) with jnp.eye(k_zz.shape[0], dtype=self.Z.dtype) fixes the NaN problem. So it isn't actually the solve; it's something to do with I_like.