Conversion is wrong when result contains structural 0s #17

Closed
paLeziart opened this issue Dec 6, 2024 · 2 comments

paLeziart commented Dec 6, 2024

Hello,

The conversion of a function from CasADi to JAX is broken when the symbolic result is a matrix with some structural zeros (when CasADi prints an expression, structural zeros are shown as 00 to distinguish them from actual numerical zeros, shown as 0).

Minimal example:

import casadi as ca
import jax
import jaxadi
import numpy as np

X = ca.SX.sym("x", 2)
A = np.random.randn(3, 2)
A[1, :] = 0.0
Y = ca.jacobian(A @ X, X)

# Create CasADi function
ca_foo = ca.Function("foo", [X], [Y], ["X"], ["Y"])

# Convert CasADi function to JAX
jax_foo = jaxadi.convert(ca_foo, compile=True)

# Let's test for some X
X0 = np.random.randn(2, 1)

print("casadi result:", ca_foo(X0))
print("jax result:\n", jax_foo(X0)[0])

print("casadi symbolic: ", Y)
print("jax translation: ", jaxadi.translate(ca_foo))

Results:

casadi result: 
[[0.65246, -1.23352], 
 [00, 00], 
 [1.02276, -0.815194]]

jax result:
 [[ 0.6524603  -0.81519425]
 [ 1.0227624   0.        ]
 [-1.2335179   0.        ]]

Symbolic expressions:

casadi symbolic:  
[[0.65246, -1.23352], 
 [00, 00], 
 [1.02276, -0.815194]]

jax translation: 
def evaluate_foo(*args):
    inputs = [jnp.expand_dims(jnp.array(arg), axis=-1) for arg in args]
    o = [jnp.zeros(out) for out in [(3, 2)]]
    o[0] = o[0].at[([0, 1, 2, 0], [0, 0, 0, 1])].set([jnp.array([0.6524602696421827])[0], jnp.array([1.0227623641242620])[0], jnp.array([-1.2335179321922756])[0], jnp.array([-0.8151942704524572])[0]])
    return o

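We can see that the indices at[([0, 1, 2, 0], [0, 0, 0, 1])] of the JAX function are wrong. The four values are the nonzeros of Y in CasADi's column-major order, which belong at positions (0, 0), (2, 0), (0, 1) and (2, 1). The translation instead writes them to the first four entries of a dense 3x2 matrix in column-major order, ignoring the structurally zero row.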
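To make the mismatch concrete, here is a small NumPy sketch that reproduces both results from the four nonzero values. It assumes the translator maps nonzero k to dense entry k in column-major order; this is inferred from the printed output, not taken from the jaxadi source. The compressed column storage arrays colind and row are written out by hand for this example (in CasADi they could presumably be read from Y.sparsity().colind() and Y.sparsity().row()).

```python
import numpy as np

# Nonzero values of Y as listed in the generated code, i.e. in CasADi's
# column-major nonzero order (copied from the issue output above).
nz = np.array([0.6524602696421827, 1.0227623641242620,
               -1.2335179321922756, -0.8151942704524572])
shape = (3, 2)

# Compressed column storage (CCS) pattern of Y, written out by hand:
# the nonzeros of column j are nz[colind[j]:colind[j + 1]], at rows row[...].
# Row 1 is structurally zero, so it never appears in `row`.
colind = [0, 2, 4]
row = [0, 2, 0, 2]

# Assumed buggy mapping: nonzero k is placed at dense entry k (column-major),
# ignoring the sparsity pattern. This reproduces the indices in the issue.
bad_rows, bad_cols = np.unravel_index(np.arange(nz.size), shape, order="F")

# Sparsity-aware mapping: expand colind into one column index per nonzero.
cols = np.repeat(np.arange(shape[1]), np.diff(colind))
good_rows = np.array(row)

bad = np.zeros(shape)
bad[bad_rows, bad_cols] = nz    # same layout as the wrong JAX result
good = np.zeros(shape)
good[good_rows, cols] = nz      # same layout as the CasADi result
```

Reading the column of each nonzero from colind, rather than deriving it from the nonzero's position in the list, is what keeps the structurally zero row empty.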

If all coefficients are non-zero (i.e. when removing A[1, :] = 0.0), the conversion is correct:

casadi result: 
[[0.44682, 0.0795599], 
 [-0.856165, 0.137537], 
 [-0.536343, 1.17636]]
jax result:
 [[ 0.44681963  0.07955994]
 [-0.85616523  0.1375372 ]
 [-0.5363429   1.1763588 ]]

Best,

paLeziart changed the title from "Conversion is wrong when result contains 0 coefficients" to "Conversion is wrong when result contains structural 0s" on Dec 6, 2024
Author

paLeziart commented Dec 6, 2024

A quick fix would be to call casadi.densify() internally on all CasADi expressions before converting them, so that structural 0s become regular 0s, which are handled properly. But densify() might have some side effects that I am not aware of.

It is also not optimal, because it leads to many assignments of 0s to coefficients that are already 0 (the output is initialized with o = [jnp.zeros(out) for out in [(3, 2)]]), which means many useless operations if the result is a huge sparse matrix.
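A minimal NumPy sketch of what densifying buys and what it costs. It assumes CasADi's compressed column storage (colind/row) and the sparsity-unaware, column-major scatter inferred from the output above; densify_ccs and scatter_dense are hypothetical helpers for illustration, not jaxadi or CasADi functions.

```python
import numpy as np

def densify_ccs(nz, shape, colind, row):
    """Hypothetical helper: expand CCS nonzeros into a full column-major vector,
    turning structural zeros into explicit 0.0 entries (the effect of densifying)."""
    nrow, ncol = shape
    dense = np.zeros(nrow * ncol)
    for j in range(ncol):
        for k in range(colind[j], colind[j + 1]):
            dense[j * nrow + row[k]] = nz[k]
    return dense

def scatter_dense(values, shape):
    """Hypothetical sparsity-unaware scatter: value k goes to dense entry k
    (column-major), like the indices generated in the issue."""
    rows, cols = np.unravel_index(np.arange(len(values)), shape, order="F")
    out = np.zeros(shape)
    out[rows, cols] = values
    return out

# Issue example: after densifying, the naive scatter puts every value in place.
nz = [0.6524602696421827, 1.0227623641242620,
      -1.2335179321922756, -0.8151942704524572]
Y = scatter_dense(densify_ccs(nz, (3, 2), [0, 2, 4], [0, 2, 0, 2]), (3, 2))

# Cost: for an n x n diagonal output, densifying writes n * n values,
# whereas a sparsity-aware scatter would write only the n nonzeros.
n = 1000
dense_writes = densify_ccs(np.ones(n), (n, n), list(range(n + 1)), list(range(n))).size
sparse_writes = n
```

For the 3x2 example the overhead is 6 writes instead of 4, but for a 1000x1000 diagonal output it is 1,000,000 writes instead of 1,000, which is the concern raised above for huge sparse matrices.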

Member

mattephi commented Dec 7, 2024

casadi.densify()

Nice catch! Actually, I think that densify is the way to go. It should not change the result of the computation, but it does affect the representation of the output. I have tried running the original example with the translation from CasADi as here, and it does not work properly either.

I think densify should suffice, as we do not expect translation to be performed in real time. Rather, we assume that some finite time is available for conversion, after which execution performance is what matters.

I have added a test for this case in 0e5b0da. For now, as we do not know a better solution, let's densify it.
