
Problem with nuclear norm, safe_map() argument 2 is longer than argument 1 #904

Closed
JadM133 opened this issue Dec 3, 2024 · 5 comments

Comments


JadM133 commented Dec 3, 2024

Related to #716: I am hitting the same error when including the nuclear norm of a model weight in the loss.

Minimal reproducible example:

step_st = [3000, 3000]
lr_st = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]
batch_size_st = [16, 16, 16, 16, 32]
training_key = jrandom.key(0)

model = eqx.nn.MLP(5, 1, 64, 2, key=jrandom.key(10))
input = jrandom.normal(jrandom.key(0), (100, 5))
output = input**2

@eqx.filter_value_and_grad
def loss_fun(model, inp, out):
    pr = jax.vmap(model)(inp)
    loss = jnp.mean((pr - out) ** 2)
    return loss + jnp.linalg.norm(model.layers[0].weight, "nuc")

@eqx.filter_jit
def make_step(model, inp, out, opt_state):
    loss, grads = loss_fun(model, inp, out)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
    t_t = 0
    optim = optax.adabelief(lr)
    filtered_model = eqx.filter(model, eqx.is_inexact_array)
    opt_state = optim.init(filtered_model)

    for step, (input_b, out) in zip(
        range(steps),
        dataloader(
            [input, output],
            batch_size,
            key=training_key,
        ),
    ):

        loss, model, opt_state = make_step(model, input_b, out, opt_state)
        print("Working...")

Here dataloader is a typical shuffling dataloader (its full definition is given in a later comment).

Note: It works fine for the Frobenius norm (default).

I am unsure whether this is a problem with Equinox or the same bug in the new version of JAX.
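
For reference, a minimal sketch that isolates the difference outside of Equinox and the training loop (my assumption being that the nuclear norm is computed via an SVD, so its gradient takes a different code path than the elementwise Frobenius norm):

import jax
import jax.numpy as jnp
import jax.random as jrandom

# Stand-in weight matrix with the same shape as model.layers[0].weight above.
W = jrandom.normal(jrandom.key(0), (64, 5))

frob_grad = jax.grad(lambda w: jnp.linalg.norm(w))(W)        # works
nuc_grad = jax.grad(lambda w: jnp.linalg.norm(w, "nuc"))(W)  # the failing call in my environment
print(frob_grad.shape, nuc_grad.shape)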


JadM133 commented Dec 3, 2024

Also, if it helps, these are the versions:
jax: 0.4.35
equinox: 0.11.8

patrick-kidger (Owner) commented Dec 3, 2024

Can you give a standalone script that can be run, including imports etc.?


JadM133 commented Dec 3, 2024

Sure, here it is:

import jax.random as jrandom
import equinox as eqx
import optax
import jax

import jax.numpy as jnp
from operator import itemgetter


def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    arrays = [array if array is not None else [None] * dataset_size for array in arrays]
    indices = jnp.arange(dataset_size)

    while True:
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        start = 0
        end = batch_size
        while end <= dataset_size:
            batch_perm = perm[start:end]
            arrs = tuple(
                itemgetter(*batch_perm)(array) for array in arrays
            )  # Works for lists and arrays
            yield [jnp.array(arr) for arr in arrs]
            start = end
            end = start + batch_size


if __name__ == "__main__":

    step_st = [3000, 3000]
    lr_st = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]
    batch_size_st = [16, 16, 16, 16, 32]
    training_key = jrandom.key(0)

    model = eqx.nn.MLP(5, 1, 64, 2, key=jrandom.key(10))
    input = jrandom.normal(jrandom.key(0), (100, 5))
    output = input**2

    @eqx.filter_value_and_grad
    def loss_fun(model, inp, out):
        pr = jax.vmap(model)(inp)
        loss = jnp.mean((pr - out) ** 2)
        return loss + jnp.linalg.norm(model.layers[0].weight, "nuc")

    @eqx.filter_jit
    def make_step(model, inp, out, opt_state):
        loss, grads = loss_fun(model, inp, out)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
        t_t = 0
        optim = optax.adabelief(lr)
        filtered_model = eqx.filter(model, eqx.is_inexact_array)
        opt_state = optim.init(filtered_model)

        for step, (input_b, out) in zip(
            range(steps),
            dataloader(
                [input, output],
                batch_size,
                key=training_key,
            ),
        ):

            loss, model, opt_state = make_step(model, input_b, out, opt_state)
            print("Working...")

If you remove the "nuc" option the code runs; otherwise you get the error mentioned above.

johannahaffner commented Dec 3, 2024

Hi Jad,
I cannot reproduce this error with JAX 0.4.35, any of the three latest equinox versions (up to 0.11.9), and optax 0.2.4 or 0.2.3. For me it runs fine.


JadM133 commented Dec 3, 2024

Hello @johannahaffner, thanks for the prompt and useful reply.

Uninstalling and reinstalling JAX resolved the issue. I am not sure what happened there, but it works now, so I am closing the issue.
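
In case anyone else hits this, my guess is that the reinstall replaced a stale or mismatched jax/jaxlib pair. A quick, hypothetical sanity check (assuming both packages import):

import jax
import jaxlib

# A consistent install should report versions from the same release pair here.
print(jax.__version__, jaxlib.__version__)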

JadM133 closed this as completed Dec 3, 2024