
RNN classifies MNIST #645

Open
never-to-never opened this issue May 4, 2023 · 3 comments
@never-to-never

test.txt
I am using an LSTM to classify the MNIST data and find that the network's loss does not converge at all. Is the RNN provided by the framework correct? I have attached the script that I run.

@IanQS

IanQS commented May 4, 2023

  1. Please paste your code rather than attaching it as a file, especially when the code is short.

  2. Why are you using an RNN?

  3. From looking at your code, it doesn't seem like you're actually using the time component. Are you sure that your preprocessing replicates the data over the time axis? (See the sketch below for one way to do this.)
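
A minimal sketch (not from the original comment) of time-major preprocessing, assuming batches come from the DataLoader as (batch, 1, 28, 28) tensors and that hk.dynamic_unroll expects (time, batch, features) input:

import jax.numpy as jnp

def to_time_major(x_torch):
    # x_torch: a (batch, 1, 28, 28) CPU tensor from the DataLoader.
    x = jnp.asarray(x_torch.numpy())    # PyTorch tensor -> JAX array
    x = jnp.squeeze(x, axis=1)          # (batch, 28, 28): drop the channel axis
    return jnp.transpose(x, (1, 0, 2))  # (28, batch, 28): row t of each image becomes timestep t

Note that a plain reshape(sequence_length, -1, input_size) preserves flat memory order, so timestep t ends up holding rows drawn from different images; transpose moves the axes without reordering the underlying data.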

@never-to-never
Author

never-to-never commented May 5, 2023

import haiku as hk
import jax
import jax.numpy as jnp
import optax
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

sequence_length = 28
input_size = 28
hidden_size = 128
num_classes = 10
batch_size = 128
num_epochs = 30
learning_rate = 0.001

train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor(), download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

def unroll_net(seqs: jax.Array):
    # seqs is time-major: (time, batch, features).
    core = hk.LSTM(128)
    batch_size = seqs.shape[1]
    outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(batch_size))
    # Classify from the LSTM output at the final timestep.
    return hk.Linear(10)(outs[-1]), state

model = hk.transform(unroll_net)

rng = jax.random.PRNGKey(428)
opt = optax.adam(1e-3)

@jax.jit
def loss(params, x, y):
    # Mean squared error between the logits and the one-hot targets.
    pred, _ = model.apply(params, None, x)
    return jnp.mean(jnp.square(pred - y))

@jax.jit
def accuracy(predict, target):
    # Count of correct predictions in the batch.
    return jnp.sum(jnp.argmax(predict, axis=1) == jnp.argmax(target, axis=1))

@jax.jit
def update(step, params, opt_state, x, y):
    # One optimisation step: gradient of the loss, then the Adam update.
    l, grads = jax.value_and_grad(loss)(params, x, y)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return l, params, opt_state

train_ds = iter(train_loader)
valid_ds = iter(test_loader)
# Initialise parameters from one sample batch reshaped to (time, batch, features).
sample_x, _ = next(train_ds)
sample_x = sample_x.reshape(sequence_length, -1, input_size)
sample_x = jnp.asarray(sample_x)
params = model.init(rng, sample_x)
opt_state = opt.init(params)
length = len(train_ds)

for step in range(length - 1):
    if step % 10 == 0:
        # Periodically evaluate the loss on one validation batch.
        x, y = next(valid_ds)
        x = x.reshape(sequence_length, -1, input_size)
        x = jnp.asarray(x)
        y = jnp.asarray(y)
        y = jnp.array(y[:, None] == jnp.arange(10), jnp.float32)  # one-hot labels
        print("Step {}: valid loss {}".format(step, loss(params, x, y)))
    x, y = next(train_ds)
    x = x.reshape(sequence_length, -1, input_size)
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    y = jnp.array(y[:, None] == jnp.arange(10), jnp.float32)  # one-hot labels
    train_loss, params, opt_state = update(step, params, opt_state, x, y)
    if step % 10 == 0:
        print("Step {}: train loss {}".format(step, train_loss))

Here is the full code.
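
The script defines accuracy but never calls it; as a minimal, hypothetical sketch (reusing the names defined above and the same preprocessing), evaluating it on one validation batch could look like:

x, y = next(valid_ds)
x = jnp.asarray(x.reshape(sequence_length, -1, input_size))
y = jnp.array(jnp.asarray(y)[:, None] == jnp.arange(10), jnp.float32)
pred, _ = model.apply(params, None, x)
print("valid accuracy:", accuracy(pred, y) / y.shape[0])  # fraction correct in the batch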

@Ekundayo39283

Ekundayo39283 commented Apr 12, 2024


The error in the code is likely due to a mismatch between PyTorch tensors and JAX arrays: train_loader and test_loader yield PyTorch tensors, while the model and loss functions expect JAX arrays. Convert the PyTorch tensors to JAX arrays with jnp.array() (or jnp.asarray()) before passing them to the model and loss functions.
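
For example (a minimal sketch, not taken from the thread; it assumes CPU tensors and integer class labels):

import jax.numpy as jnp

x_torch, y_torch = next(iter(train_loader))  # PyTorch tensors from the DataLoader
x = jnp.array(x_torch.numpy())               # images as a JAX array
y = jnp.array(y_torch.numpy())               # integer labels as a JAX array
y = jnp.array(y[:, None] == jnp.arange(10), jnp.float32)  # one-hot encode, as the posted script already does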
