
Conversion time for sparse matrix #19

Open

paLeziart opened this issue Dec 11, 2024 · 3 comments

paLeziart commented Dec 11, 2024

Hello again!

Thank you for your quick fixes for #16 and #17!

This is more of a discussion than an issue, so feel free to close it at some point.

Since I have to convert big sparse matrices, I am facing long compilation times with convert(foo, compile=True): the generated code does not ignore the zeros and still assigns them to the output, which means a huge number of assignment operations even though there are few non-zero coefficients.

For instance, the following code takes 10 seconds to run for N = 100 and 76 seconds for N = 200 (and I will likely have to convert even bigger matrices). I guess the compiler struggles with the high number of operations and does not realize it can skip most of them?

import time

import casadi as ca
from jaxadi import convert

N = 200
x = ca.SX.sym("x", N)
y = ca.SX.sym("y", (N, N)) * 0.0  # all-zero SX matrix of the right shape
for i in range(N):
    y[i, i] = 2 * x[i]
foo = ca.Function("foo", [x], [y])

start = time.time()
convert(foo, compile=True)
print("Convert time:", time.time() - start)

Note that this is just an example; for this diagonal matrix I could of course code it directly in JAX.
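
For reference, a direct JAX equivalent of this toy diagonal case would be a one-liner (just a sketch to illustrate, not what I actually need):

import jax.numpy as jnp

def foo_jax(x):
    # Diagonal matrix with 2 * x on the diagonal
    return jnp.diag(2 * x)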

As a quick workaround, I slightly modified the codegen function to skip the output assignments associated with a zero value:

for layer in sorted_nodes:
    indices = []
    # MODIF HERE: track work entries that hold the constant zero
    zero_nodes = []
    assignment = "["
    for node in layer:
        if len(graph[node]) == 0 and node not in output_map:
            continue
        if node in output_map:
            oo = output_map[node]
            if outputs.get(oo[0], None) is None:
                outputs[oo[0]] = {"rows": [], "cols": [], "values": []}

            # MODIF HERE: skip the output assignment if it reads from a
            # zero entry (values[node] looks like "work[23][0]", so parse
            # out the full work index, not just its last digit)
            if int(values[node].split("[")[1].split("]")[0]) in zero_nodes:
                continue

            outputs[oo[0]]["rows"].append(oo[1])
            outputs[oo[0]]["cols"].append(oo[2])
            outputs[oo[0]]["values"].append(values[node])
        else:
            if len(assignment) > 1:
                assignment += ", "

            # MODIF HERE: a constant like "jnp.array([0.0000000000000000])"
            # contains no non-zero digit, so mark the node as zero
            if "jnp" in values[node] and not any((char.isdigit() and char != '0') for char in values[node]):
                zero_nodes.append(node)

            assignment += values[node]
            indices += [node]
    if len(indices) == 0:
        continue
    assignment += "]"
    code += f"    work = work.at[jnp.array({indices})].set({assignment})\n"

All the tests seem to pass, and it brought the convert time down from 76 seconds to less than 2 seconds for N = 200.

I'm sure there is a better way to do that, but I just wanted to share the code in case other people are facing the same struggles with sparse matrices.

For completeness, the translation (for N = 4 just to illustrate) goes from

def evaluate_foo(*args):
    inputs = [jnp.expand_dims(jnp.ravel(jnp.array(arg).T), axis=-1) for arg in args]
    outputs = [jnp.zeros(out) for out in [(4, 4)]]
    work = jnp.zeros((26, 1))
    work = work.at[jnp.array([0, 1, 4, 9, 16, 23])].set([jnp.array([2.0000000000000000]), inputs[0][0], jnp.array([0.0000000000000000]), inputs[0][1], inputs[0][2], inputs[0][3]])
    work = work.at[jnp.array([2, 10, 17, 24])].set([work[0] * work[1], work[0] * work[9], work[0] * work[16], work[0] * work[23]])
    outputs[0] = outputs[0].at[([1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 0, 1, 2, 3], [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 0, 1, 2, 3])].set([work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[2][0], work[10][0], work[17][0], work[24][0]])
    return outputs

to

def evaluate_foo(*args):
    inputs = [jnp.expand_dims(jnp.ravel(jnp.array(arg).T), axis=-1) for arg in args]
    outputs = [jnp.zeros(out) for out in [(4, 4)]]
    work = jnp.zeros((26, 1))
    work = work.at[jnp.array([0, 1, 4, 9, 16, 23])].set([jnp.array([2.0000000000000000]), inputs[0][0], jnp.array([0.0000000000000000]), inputs[0][1], inputs[0][2], inputs[0][3]])
    work = work.at[jnp.array([2, 10, 17, 24])].set([work[0] * work[1], work[0] * work[9], work[0] * work[16], work[0] * work[23]])
    outputs[0] = outputs[0].at[([0, 1, 2, 3], [0, 1, 2, 3])].set([work[2][0], work[10][0], work[17][0], work[24][0]])
    return outputs

We can see that it simply discards the useless assignments of work[4][0] (the zero constant) to outputs[0] without affecting the rest.

Best,

mattephi (Member) commented Dec 12, 2024

This sounds very good and aligns perfectly with our view of the tool. Since we presumably have unlimited time for precomputation and need to generate as efficient a function as possible, it makes sense to apply optimizations that make the code more performant.

I suggest you simply create a pull request with your idea and contribute; we will be happy to accept it when we get a little more free time. Otherwise we will probably implement this ourselves, just a little later.

paLeziart (Author) commented

@mattephi Thanks for your answer; I will open the pull request when I have some time 👍

Removing the zeros helped a bit, but I am actually still struggling with the conversion times. This is unrelated to jaxadi though: the JAX compile() call after lowering the function takes a surprisingly long time, which is a bit of a pain when prototyping and debugging. Even with the JAX cache fully enabled, I have not found a way to avoid it yet...
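
For reference, enabling the cache looks roughly like this (a sketch; these are the options documented for the persistent compilation cache in recent JAX releases):

import jax

# Write compiled executables to disk so later runs can reuse them
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# Cache every entry, regardless of size or compile time
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)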

mattephi (Member) commented

> The jax compile() after lowering the function takes a surprising amount of time

I think this is the main issue with JAX, as well as with other tools such as mjx. Compile times are horrendous and many people complain about them. There are some approaches to the problem, like applying vmap before any jit and that kind of thing, but there is no general solution.
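
A minimal sketch of the vmap-before-jit pattern (hypothetical f and shapes, just to illustrate):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x @ x.T).sum()

xs = jnp.ones((64, 8, 8))

# Looping over a per-sample jitted function dispatches once per sample
# (and recompiles whenever the per-sample shape changes):
f_jit = jax.jit(f)
slow = jnp.stack([f_jit(x) for x in xs])

# Vectorizing first and jitting once compiles a single batched program:
fast = jax.jit(jax.vmap(f))(xs)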
