
[Bug] Jax v0.4.36 Breaks Cola Solve #112

Open
adam-hartshorne opened this issue Dec 6, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@adam-hartshorne

It appears that the new release of JAX v0.4.36 breaks cola.linalg.solve with the Cholesky solver. It results in NaN gradients.

  k_zz_inv_U = cola.linalg.solve(K, self.U.reshape(self.M, -1), alg=Cholesky())
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/cola/linalg/inverse/inv.py", line 39, in solve
    return inv(A, alg) @ b
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/plum/dispatcher.py", line 93, in new_fn
    return fn(*new_args)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/plum/function.py", line 444, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/cola/linalg/inverse/inv.py", line 98, in inv
    L = cholesky(A)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/plum/function.py", line 444, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/cola/linalg/decompositions/decompositions.py", line 154, in cholesky
    return Triangular(A.xnp.cholesky(A.to_dense()), lower=True)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/cola/ops/operator_base.py", line 77, in to_dense
    return self @ self.xnp.eye(self.shape[-1], self.shape[-1], dtype=self.dtype, device=self.device)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/cola/backends/jax_fns.py", line 154, in eye
    return jnp.eye(N=n, M=m, dtype=dtype)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 6636, in eye
    output = _eye(N, M=M, k=k, dtype=dtype)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 6659, in _eye
    return (i + offset == j).astype(dtype)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 1049, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 573, in deferring_binary_op
    return binary_op(*args)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.12/site-packages/jax/_src/numpy/ufunc_api.py", line 179, in __call__
    return call(*args)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Cannot lower jaxpr with verifier errors:
	type of return operand 0 ('tensor<50x50xi64>') doesn't match function result type ('tensor<50x50xi32>') in function @main
		at loc(unknown)
	see current operation: "func.return"(%1) : (tensor<50x50xi64>) -> ()
		at loc(unknown)
Define JAX_DUMP_IR_TO to dump the module.
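The last frames show the failure coming from jnp.eye(N=n, M=m, dtype=dtype), which CoLA's to_dense reaches through cola/backends/jax_fns.py, and the verifier message is an integer-width mismatch (i64 vs. i32). A minimal sketch for exercising that call in isolation, using only public JAX: eye_lowering_report is a hypothetical helper that builds the identity both eagerly and under jax.jit (which forces tracing and lowering) and reports the default integer dtype, which is int64 only when jax_enable_x64 is on. It is a diagnostic aid, not a claim about the root cause.

```python
import jax
import jax.numpy as jnp


def eye_lowering_report(n, dtype):
    """Hypothetical diagnostic: build an n x n identity the way the traceback
    does (jnp.eye(N=n, M=n, dtype=dtype)), eagerly and under jax.jit, and
    report the resulting dtypes alongside the default integer dtype."""
    eager = jnp.eye(N=n, M=n, dtype=dtype)
    # jax.jit traces and lowers the computation, the stage at which the
    # reported "Cannot lower jaxpr with verifier errors" ValueError appeared.
    jitted = jax.jit(lambda: jnp.eye(N=n, M=n, dtype=dtype))()
    return {
        "jax_version": jax.__version__,
        # int32 when jax_enable_x64 is off, int64 when it is on.
        "default_int": jnp.asarray(1).dtype,
        "eager_dtype": eager.dtype,
        "jitted_dtype": jitted.dtype,
    }


# 50 matches the 50x50 shape in the verifier message; float32 is an assumption
# about the operator's dtype.
print(eye_lowering_report(50, jnp.float32))
```

Running this with the dtype of the operator in question (and with jax_enable_x64 set as in the failing environment) helps separate a plain jnp.eye lowering problem from something specific to how CoLA calls it.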

The code that generated this is a very simple sparse GP using inducing points, which is well tested and works fine on JAX versions 0.4.35 and below.

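from functools import partial

import cola
from cola.linalg import Cholesky

# matern52_kernel and the self.* attributes come from the author's model class (not shown).
def gp_fit(self, t):
    kern = partial(matern52_kernel, length_scale=self.length_scale, variance=self.variance)
    k_zz = cola.ops.Dense(kern(self.Z, self.Z))
    k_xz = cola.ops.Dense(kern(t, self.Z))
    # Inducing-point covariance with noise plus jitter added to the diagonal
    K = cola.PSD(k_zz + ((self.noise + self.jitter) * cola.ops.I_like(k_zz)))
    k_zz_inv_U = cola.linalg.solve(K, self.U.reshape(self.M, -1), alg=Cholesky())
    return k_xz @ k_zz_inv_U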
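For comparison with the CoLA version, here is a minimal pure-JAX sketch of the same computation, with jax.scipy.linalg.cho_factor/cho_solve standing in for cola.linalg.solve(..., alg=Cholesky()). Assumptions: matern52_kernel below is a hypothetical Matern-5/2 implementation (the issue does not include the author's), gp_fit_dense is a hypothetical standalone function taking the model attributes as arguments, and the identity is built densely with an explicit dtype, as in the workaround the author reports later in the thread. It is meant as a CoLA-independent reference for checking outputs and gradients, not as the author's code.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def matern52_kernel(x, y, length_scale, variance):
    """Hypothetical Matern-5/2 kernel between rows of x (n, d) and y (m, d)."""
    sq = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    # Small offset keeps the gradient of sqrt finite when two points coincide.
    r = jnp.sqrt(sq + 1e-12)
    s = jnp.sqrt(5.0) * r / length_scale
    return variance * (1.0 + s + s**2 / 3.0) * jnp.exp(-s)


def gp_fit_dense(Z, U, t, length_scale, variance, noise, jitter):
    """Hypothetical dense counterpart of gp_fit: k_xz @ (k_zz + (noise + jitter) I)^-1 U."""
    kern = partial(matern52_kernel, length_scale=length_scale, variance=variance)
    k_zz = kern(Z, Z)
    k_xz = kern(t, Z)
    # Dense identity with an explicit dtype instead of cola.ops.I_like.
    K = k_zz + (noise + jitter) * jnp.eye(k_zz.shape[0], dtype=Z.dtype)
    k_zz_inv_U = cho_solve(cho_factor(K, lower=True), U.reshape(Z.shape[0], -1))
    return k_xz @ k_zz_inv_U


# Usage with made-up data: 50 inducing points in 2D, 10 query points.
key_z, key_t = jax.random.split(jax.random.PRNGKey(0))
Z = jax.random.normal(key_z, (50, 2))
t = jax.random.normal(key_t, (10, 2))
U = jnp.ones((50,))


def loss(length_scale):
    return jnp.sum(gp_fit_dense(Z, U, t, length_scale, 1.0, 0.1, 1e-6) ** 2)


g = jax.grad(loss)(1.0)
print(jnp.isfinite(g))
```

If this dense version gives finite gradients on the same inputs where the CoLA version does not, that points toward the CoLA operator construction rather than the Cholesky solve itself.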

@adam-hartshorne adam-hartshorne added the bug Something isn't working label Dec 6, 2024
@AndPotap
Collaborator

AndPotap commented Dec 9, 2024

@adam-hartshorne, thanks for bringing this discrepancy introduced in JAX v0.4.36 to our attention. I'll look into the Cholesky function in JAX and report back.

In the meantime, could you share more code to reproduce the exact error above?

@adam-hartshorne
Author

adam-hartshorne commented Dec 11, 2024

I figured out that solve isn't causing this behaviour: changing cola.ops.I_like(k_zz) to jnp.eye(k_zz.shape[0], dtype=self.Z.dtype) fixes the NaN problem. So it isn't actually the solve; it is something to do with I_like.
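To confirm that a swap like this actually removes the NaN gradients, a check along the following lines can be run on the loss before and after the change. This is a sketch; grads_are_finite is a hypothetical helper, not part of JAX or CoLA. Alternatively, enabling JAX's jax_debug_nans option makes JAX raise at the first operation that produces a NaN.

```python
import jax
import jax.numpy as jnp


def grads_are_finite(loss_fn, params):
    """Hypothetical helper: True if every leaf of grad(loss_fn)(params) is finite.

    loss_fn must return a scalar; NaN and +/-inf both count as non-finite.
    """
    grads = jax.grad(loss_fn)(params)
    leaves = jax.tree_util.tree_leaves(grads)
    return all(bool(jnp.all(jnp.isfinite(g))) for g in leaves)


# Alternative: make JAX raise as soon as a NaN is produced.
# jax.config.update("jax_debug_nans", True)

params = {"w": jnp.ones(3)}
print(grads_are_finite(lambda p: jnp.sum(p["w"] ** 2), params))  # True
# d/dw sqrt(w - 1) at w = 1 is inf, so this check reports False.
print(grads_are_finite(lambda p: jnp.sum(jnp.sqrt(p["w"] - 1.0)), params))  # False
```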
