The conversion of a function from CasADi to JAX is broken when the symbolic result is a matrix containing structural zeros (when an expression with structural zeros is printed, they appear as 00, which distinguishes them from actual numerical zeros, printed as 0).
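For background: CasADi stores sparse matrices in compressed column storage (CCS). A column-pointer array colind and a row-index array row list only the structural nonzeros, and any entry not listed is a structural zero. The sketch below uses plain NumPy and a hypothetical helper name (ccs_pattern). It builds that pattern for the Jacobian in the minimal example, assuming, as the printed 00 entries indicate, that row 1 is structurally zero. Inside CasADi, the same arrays should be available from Y.sparsity().colind() and Y.sparsity().row().

```python
import numpy as np

def ccs_pattern(mask):
    """Return (colind, row) in compressed column storage for a boolean
    sparsity mask; True entries are structural nonzeros (hypothetical helper)."""
    n_rows, n_cols = mask.shape
    colind = [0]
    row = []
    for j in range(n_cols):
        # Row indices of the structural nonzeros in column j, in ascending order
        row.extend(int(r) for r in np.flatnonzero(mask[:, j]))
        # colind[j + 1] marks where column j + 1 starts in the nonzero list
        colind.append(len(row))
    return colind, row

# Assumed pattern of Y = jacobian(A @ X, X): 3x2, row 1 structurally zero
mask = np.array([[True, True],
                 [False, False],
                 [True, True]])
colind, row = ccs_pattern(mask)
print(colind)  # [0, 2, 4]
print(row)     # [0, 2, 0, 2]
```

Only 4 of the 6 entries are stored, so any code that scatters these 4 values into a 3x2 output has to take its (row, column) positions from colind and row rather than from the value's position in the list.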
Minimal example:
import casadi as ca
import jax
import jaxadi
import numpy as np
X = ca.SX.sym("x", 2)
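A = np.random.randn(3, 2)
A[1, :] = 0.0  # row 1 of A is zero, so row 1 of the Jacobian is structurally zero
Y = ca.jacobian(A @ X, X)  # numerically equal to A, with row 1 printed as 00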
# Create CasADi function
ca_foo = ca.Function("foo", [X], [Y], ["X"], ["Y"])
# Convert CasADi function to JAX
jax_foo = jaxadi.convert(ca_foo, compile=True)
# Let's test for some X
X0 = np.random.randn(2, 1)
print("casadi result:", ca_foo(X0))
print("jax result:\n", jax_foo(X0)[0])
print("casadi symbolic: ", Y)
print("jax translation: ", jaxadi.translate(ca_foo))
A quick fix would be to call casadi.densify() internally on every CasADi expression before converting it, so that structural zeros become ordinary zeros, which are handled correctly. However, densify() might have side effects that I am not aware of.
It is also not optimal: it generates many assignments of 0 to coefficients that are already initialized to 0 (o = [jnp.zeros(out) for out in [(3, 2)]]), which means many useless operations if the result is a huge sparse matrix.
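To put a rough number on that overhead: after densify(), every structural zero becomes an explicitly stored 0, so the generated code writes rows x cols entries instead of only the nnz structural nonzeros. The sketch below compares the two counts for a hypothetical n x n tridiagonal Jacobian, which has 3n - 2 structural nonzeros. It counts writes in the generated code only; the actual runtime cost after JIT compilation may differ.

```python
def assignment_counts(n):
    """Entries written for a hypothetical n x n tridiagonal output:
    (sparse, densified)."""
    nnz = 3 * n - 2  # main diagonal (n) plus two off-diagonals (n - 1 each)
    dense = n * n    # densify() makes every entry an explicit nonzero
    return nnz, dense

for n in (3, 100, 1000):
    nnz, dense = assignment_counts(n)
    print(f"n={n}: sparse writes {nnz}, densified writes {dense} ({dense / nnz:.1f}x)")
```

The ratio grows roughly like n / 3, so the overhead is negligible for small outputs and large for big, very sparse ones.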
Nice catch! Actually, I think densify is the way to go. It should not change the result of the computation, although it does affect the representation of the output. I tried running the original example through the translation from CasADi as here, and it does not work properly either.
I think densify should suffice, since we do not expect to perform translation in real time; rather, we assume that some finite time is available for conversion, after which execution performance is what matters.
I have added a test for this case in 0e5b0da. Since we do not know of a better solution for now, let's densify it.
We can see that the indices at[([0, 1, 2, 0], [0, 0, 0, 1]) of the JAX function are wrong. If all coefficients are nonzero (i.e., removing A[1, :] = 0.0), then the conversion is correct.
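The wrong indices look like what you get by treating the k-th structural nonzero as if it sat at the k-th column-major position of a dense 3x2 matrix. Positions 0..3 map to (0,0), (1,0), (2,0) and (0,1). The sketch below uses NumPy and hypothetical helper names; it is not jaxadi's actual code. It contrasts that naive mapping with the correct one derived from the CCS pattern of Y, assumed to be colind = [0, 2, 4] and row = [0, 2, 0, 2] (row 1 structurally zero). It also checks that the two mappings agree once the pattern is dense, which would explain why densify() works around the bug.

```python
import numpy as np

def ccs_indices(colind, row):
    """Correct (rows, cols) of each structural nonzero, in CCS storage order."""
    cols = []
    for j in range(len(colind) - 1):
        # Column j owns nonzeros colind[j] .. colind[j + 1] - 1
        cols.extend([j] * (colind[j + 1] - colind[j]))
    return list(row), cols

def naive_indices(nnz, shape):
    """Suspected buggy mapping: k-th nonzero placed at the k-th
    column-major position of a dense matrix, ignoring the sparsity."""
    n_rows = shape[0]
    return [k % n_rows for k in range(nnz)], [k // n_rows for k in range(nnz)]

shape = (3, 2)
colind, row = [0, 2, 4], [0, 2, 0, 2]  # assumed pattern of Y in the minimal example
nz = np.array([1.0, 2.0, 3.0, 4.0])    # placeholder nonzero values in CCS order

print(ccs_indices(colind, row))        # ([0, 2, 0, 2], [0, 0, 1, 1])
print(naive_indices(len(row), shape))  # ([0, 1, 2, 0], [0, 0, 0, 1]), the wrong indices

# Scatter into a zero-initialized output. The generated JAX code appears to do
# jnp.zeros(shape).at[(rows, cols)].set(nz); NumPy fancy assignment is the analogue.
out = np.zeros(shape)
out[ccs_indices(colind, row)] = nz
print(out)  # row 1 stays zero, values land in rows 0 and 2

# With a dense pattern (what densify() produces), both mappings coincide.
dense_colind, dense_row = [0, 3, 6], [0, 1, 2, 0, 1, 2]
assert ccs_indices(dense_colind, dense_row) == naive_indices(6, shape)
```

Under this reading, the actual fix would be to build the scatter indices from the output's sparsity pattern. That would avoid the extra zero writes that densify() introduces.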