add test for example to confirm that the results are similar to pytorch
kigichang committed Oct 11, 2024
1 parent d16a45e commit 54d8819
Showing 2 changed files with 410 additions and 83 deletions.
152 changes: 152 additions & 0 deletions candle-examples/examples/rnn/README.md
@@ -13,3 +13,155 @@ Choose LSTM or GRU via the `--model` argument, number of layers via `--layer`, a
```bash
$ cargo run --example rnn --release -- --model lstm --layers 3 --bidirection
```

## Running the example test

Add the `--test` argument to run the tests for this example.

```bash
$ cargo run --example rnn --release -- --test
```
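
The test checks that the outputs computed by this example are close to the ones PyTorch produced. As a rough illustration of what "close" can mean, the sketch below applies the element-wise criterion documented for `torch.allclose`, `|actual - expected| <= atol + rtol * |expected|`, to flat lists of floats. The helper name `all_close` is hypothetical, and the tolerances are simply the `torch.allclose` defaults; the criterion and tolerances actually used by the example's test are not stated here and may differ.

```python
from typing import Sequence


def all_close(
    actual: Sequence[float],
    expected: Sequence[float],
    rtol: float = 1e-5,
    atol: float = 1e-8,
) -> bool:
    """Return True if every element of `actual` is within tolerance of `expected`.

    Hypothetical helper: mirrors the criterion documented for torch.allclose,
    not necessarily the check performed by the Rust example.
    """
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


if __name__ == "__main__":
    # A tiny difference passes, a large one does not.
    print(all_close([1.0, 2.0], [1.0, 2.00001]))
    print(all_close([1.0], [1.1]))
```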

The test models are generated with PyTorch's [LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) and [GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html) modules. Each file stores the model weights together with the input and output tensors, and can be downloaded from [kigichang/test_rnn](https://huggingface.co/kigichang/test_rnn) on Hugging Face.
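
To make the file contents concrete, the sketch below lists the entries one would expect to find in a test file. It assumes PyTorch's default parameter naming for `nn.LSTM` and `nn.GRU` with biases enabled (`weight_ih_l{k}`, `weight_hh_l{k}`, `bias_ih_l{k}`, `bias_hh_l{k}`, plus a `_reverse` suffix for the backward direction) and the extra `input`, `output`, `hn` and `cn` entries added by the generation scripts in this README. The function name `expected_keys` is hypothetical and is not part of the example.

```python
from typing import List


def expected_keys(num_layers: int, bidirectional: bool, is_lstm: bool) -> List[str]:
    """List the entries expected in a saved test file.

    Assumes PyTorch's default RNN parameter names (bias=True) and the extra
    tensors stored by the generation scripts.
    """
    suffixes = ["", "_reverse"] if bidirectional else [""]
    keys = []
    for layer in range(num_layers):
        for suffix in suffixes:
            for param in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"):
                keys.append(f"{param}_l{layer}{suffix}")
    # Test tensors stored alongside the weights.
    keys += ["input", "output", "hn"]
    if is_lstm:
        keys.append("cn")  # only LSTMs have a cell state
    return keys


if __name__ == "__main__":
    # Entries expected in lstm_test.pt (1 layer, unidirectional LSTM).
    print(expected_keys(num_layers=1, bidirectional=False, is_lstm=True))
```

A loader on the Rust side would need to treat the weight entries and the four test tensors differently, which is why listing them separately can be handy.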

The test models are generated by the following code:

- lstm_test.pt: A simple LSTM model with 1 layer.

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, num_layers=1, batch_first=True)
input = torch.randn(5, 3, 10)
output, (hn, cn) = rnn(input)

state_dict = rnn.state_dict()
state_dict['input'] = input
state_dict['output'] = output.contiguous()
state_dict['hn'] = hn
state_dict['cn'] = cn
torch.save(state_dict, "lstm_test.pt")
```

- gru_test.pt: A simple GRU model with 1 layer.

```python
import torch
import torch.nn as nn

rnn = nn.GRU(10, 20, num_layers=1, batch_first=True)
input = torch.randn(5, 3, 10)
output, hn = rnn(input)

state_dict = rnn.state_dict()
state_dict['input'] = input
state_dict['output'] = output.contiguous()
state_dict['hn'] = hn
torch.save(state_dict, "gru_test.pt")
```

- bi_lstm_test.pt: A bidirectional LSTM model with 1 layer.

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, num_layers=1, bidirectional=True, batch_first=True)
input = torch.randn(5, 3, 10)
output, (hn, cn) = rnn(input)

state_dict = rnn.state_dict()
state_dict['input'] = input
state_dict['output'] = output.contiguous()
state_dict['hn'] = hn
state_dict['cn'] = cn
torch.save(state_dict, "bi_lstm_test.pt")
```

- bi_gru_test.pt: A bidirectional GRU model with 1 layer.

```python
import torch
import torch.nn as nn

rnn = nn.GRU(10, 20, num_layers=1, bidirectional=True, batch_first=True)
input = torch.randn(5, 3, 10)
output, hn = rnn(input)

state_dict = rnn.state_dict()
state_dict['input'] = input
state_dict['output'] = output.contiguous()
state_dict['hn'] = hn
torch.save(state_dict, "bi_gru_test.pt")
```

- lstm_nlayer_test.pt: An LSTM model with 3 layers.

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, num_layers=3, batch_first=True)
input = torch.randn(5, 3, 10)
output, (hn, cn) = rnn(input)

state_dict = rnn.state_dict()
state_dict['input'] = input
state_dict['output'] = output.contiguous()
state_dict['hn'] = hn
state_dict['cn'] = cn
torch.save(state_dict, "lstm_nlayer_test.pt")
```

- bi_lstm_nlayer_test.pt: A bidirectional LSTM model with 3 layers.

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, num_layers=3, bidirectional=True, batch_first=True)
input = torch.randn(5, 3, 10)
output, (hn, cn) = rnn(input)

state_dict = rnn.state_dict()
state_dict['input'] = input
state_dict['output'] = output.contiguous()
state_dict['hn'] = hn
state_dict['cn'] = cn
torch.save(state_dict, "bi_lstm_nlayer_test.pt")
```

- gru_nlayer_test.pt: A GRU model with 3 layers.

```python
import torch
import torch.nn as nn

rnn = nn.GRU(10, 20, num_layers=3, batch_first=True)
input = torch.randn(5, 3, 10)
output, hn = rnn(input)

state_dict = rnn.state_dict()
state_dict['input'] = input
state_dict['output'] = output.contiguous()
state_dict['hn'] = hn
torch.save(state_dict, "gru_nlayer_test.pt")
```

- bi_gru_nlayer_test.pt: A bidirectional GRU model with 3 layers.

```python
import torch
import torch.nn as nn

rnn = nn.GRU(10, 20, num_layers=3, bidirectional=True, batch_first=True)
input = torch.randn(5, 3, 10)
output, hn = rnn(input)

state_dict = rnn.state_dict()
state_dict['input'] = input
state_dict['output'] = output.contiguous()
state_dict['hn'] = hn
torch.save(state_dict, "bi_gru_nlayer_test.pt")
```
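
All of the scripts above share the same sizes (batch 5, sequence length 3, input size 10, hidden size 20) and differ only in layer count, direction, and cell type. The sketch below computes the tensor shapes these settings should produce, following the shape rules in the PyTorch documentation for `batch_first=True`: `output` is `(batch, seq_len, D * hidden_size)` and `hn`/`cn` are `(D * num_layers, batch, hidden_size)`, where `D` is 2 for bidirectional models and 1 otherwise. The helper name `expected_shapes` is hypothetical.

```python
from typing import Dict, Tuple


def expected_shapes(
    num_layers: int,
    bidirectional: bool,
    is_lstm: bool,
    batch: int = 5,
    seq_len: int = 3,
    input_size: int = 10,
    hidden_size: int = 20,
) -> Dict[str, Tuple[int, ...]]:
    """Shapes of the test tensors saved by the scripts above (batch_first=True)."""
    directions = 2 if bidirectional else 1
    shapes: Dict[str, Tuple[int, ...]] = {
        "input": (batch, seq_len, input_size),
        "output": (batch, seq_len, directions * hidden_size),
        # The final states are not affected by batch_first.
        "hn": (directions * num_layers, batch, hidden_size),
    }
    if is_lstm:
        shapes["cn"] = shapes["hn"]  # the cell state has the same shape as hn
    return shapes


if __name__ == "__main__":
    # Shapes expected in bi_lstm_nlayer_test.pt (3 layers, bidirectional LSTM).
    for name, shape in expected_shapes(3, True, True).items():
        print(name, shape)
```

Comparing shapes first can help separate layout mistakes (for example, a missing direction in `hn`) from genuine numerical differences.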