diff --git a/candle-examples/examples/rnn/README.md b/candle-examples/examples/rnn/README.md
new file mode 100644
index 000000000..5ba378f89
--- /dev/null
+++ b/candle-examples/examples/rnn/README.md
@@ -0,0 +1,167 @@
+# candle-rnn: Recurrent Neural Network
+
+This example demonstrates how to use the `candle_nn::rnn` module to run LSTM and GRU networks, including bidirectional and multi-layer variants.
+
+## Running the example
+
+```bash
+$ cargo run --example rnn --release
+```
+
+Choose LSTM or GRU via the `--model` argument, set the number of layers via `--layers`, and enable bidirectional processing via `--bidirection`.
+
+```bash
+$ cargo run --example rnn --release -- --model lstm --layers 3 --bidirection
+```
+
+## Running the example test
+
+Add the `--test` argument to run the tests for this example.
+
+```bash
+$ cargo run --example rnn --release -- --test
+```
+
+The test models are generated with reference to the PyTorch [LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) and [GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html) documentation. These models include input and output tensors and can be downloaded from [here](https://huggingface.co/kigichang/test_rnn).
+
+The test models are generated by the following code:
+
+- lstm_test.pt: A simple LSTM model with 1 layer.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.LSTM(10, 20, num_layers=1, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, (hn, cn) = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  state_dict['cn'] = cn
+  torch.save(state_dict, "lstm_test.pt")
+  ```
+
+- gru_test.pt: A simple GRU model with 1 layer.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.GRU(10, 20, num_layers=1, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, hn = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  torch.save(state_dict, "gru_test.pt")
+  ```
+
+- bi_lstm_test.pt: A bidirectional LSTM model with 1 layer.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.LSTM(10, 20, num_layers=1, bidirectional=True, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, (hn, cn) = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  state_dict['cn'] = cn
+  torch.save(state_dict, "bi_lstm_test.pt")
+  ```
+
+- bi_gru_test.pt: A bidirectional GRU model with 1 layer.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.GRU(10, 20, num_layers=1, bidirectional=True, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, hn = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  torch.save(state_dict, "bi_gru_test.pt")
+  ```
+
+- lstm_nlayer_test.pt: An LSTM model with 3 layers.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.LSTM(10, 20, num_layers=3, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, (hn, cn) = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  state_dict['cn'] = cn
+  torch.save(state_dict, "lstm_nlayer_test.pt")
+  ```
+
+- bi_lstm_nlayer_test.pt: A bidirectional LSTM model with 3 layers.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.LSTM(10, 20, num_layers=3, bidirectional=True, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, (hn, cn) = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  state_dict['cn'] = cn
+  torch.save(state_dict, "bi_lstm_nlayer_test.pt")
+  ```
+
+- gru_nlayer_test.pt: A GRU model with 3 layers.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.GRU(10, 20, num_layers=3, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, hn = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  torch.save(state_dict, "gru_nlayer_test.pt")
+  ```
+
+- bi_gru_nlayer_test.pt: A bidirectional GRU model with 3 layers.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.GRU(10, 20, num_layers=3, bidirectional=True, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, hn = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  torch.save(state_dict, "bi_gru_nlayer_test.pt")
+  ```
diff --git a/candle-examples/examples/rnn/main.rs b/candle-examples/examples/rnn/main.rs
new file mode 100644
index 000000000..83364a8ad
--- /dev/null
+++ b/candle-examples/examples/rnn/main.rs
@@ -0,0 +1,401 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Result;
+use candle::{DType, Device, Tensor, D};
+use candle_nn::{rnn, LSTMConfig, VarBuilder, RNN};
+use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+
+const ACCURACY: f32 = 1e-6;
+
+#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
+enum WhichModel {
+    #[value(name = "lstm")]
+    LSTM,
+    #[value(name = "gru")]
+    GRU,
+}
+
+#[derive(Clone, Copy, Debug, Parser)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    #[arg(long)]
+    cpu: bool,
+
+    #[arg(long, default_value_t = 10)]
+    input_dim: usize,
+
+    #[arg(long, default_value_t = 20)]
+    hidden_dim: usize,
+
+    #[arg(long, default_value_t = 1)]
+    layers: usize,
+
+    #[arg(long)]
+    bidirection: bool,
+
+    #[arg(long, default_value_t = 5)]
+    batch_size: usize,
+
+    #[arg(long, default_value_t = 3)]
+    seq_len: usize,
+
+    #[arg(long, default_value = "lstm")]
+    model: WhichModel,
+
+    #[arg(long)]
+    test: bool,
+}
+
+impl Args {
+    pub fn load_model(&self) -> Result<(Config, VarBuilder<'static>, Tensor)> {
+        let device = self.device()?;
+        if self.test {
+            // run unit test and download model from huggingface hub.
+            let model = match self.model {
+                WhichModel::LSTM => "lstm",
+                WhichModel::GRU => "gru",
+            };
+
+            let bidirection = if self.bidirection { "bi_" } else { "" };
+            let layer = if self.layers > 1 { "_nlayer" } else { "" };
+            let model = format!("{}{}{}_test", bidirection, model, layer);
+            let (config, vb) = load_model(&model, &device)?;
+            let input = vb.get(
+                (config.batch_size, config.sequence_length, config.input),
+                "input",
+            )?;
+            Ok((config, vb, input))
+        } else {
+            let map = candle_nn::VarMap::new();
+            let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);
+            let input = Tensor::randn(
+                0.0_f32,
+                1.0,
+                (self.batch_size, self.seq_len, self.input_dim),
+                &device,
+            )?;
+            Ok((self.into(), vb, input))
+        }
+    }
+
+    pub fn device(&self) -> Result<Device> {
+        Ok(candle_examples::device(self.cpu)?)
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+struct Config {
+    pub input: usize,
+    pub batch_size: usize,
+    pub sequence_length: usize,
+    pub hidden: usize,
+    pub layers: usize,
+    pub bidirection: bool,
+}
+
+impl From<&Args> for Config {
+    fn from(args: &Args) -> Self {
+        Config {
+            input: args.input_dim,
+            batch_size: args.batch_size,
+            sequence_length: args.seq_len,
+            hidden: args.hidden_dim,
+            layers: args.layers,
+            bidirection: args.bidirection,
+        }
+    }
+}
+
+fn load_model(model: &str, device: &Device) -> Result<(Config, VarBuilder<'static>)> {
+    let api = Api::new()?;
+    let repo_id = "kigichang/test_rnn".to_string();
+    let repo = api.repo(Repo::with_revision(
+        repo_id,
+        RepoType::Model,
+        "main".to_string(),
+    ));
+
+    let filename = repo.get(&format!("{}.pt", model))?;
+    let config_file = repo.get(&format!("{}.json", model))?;
+
+    let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
+    let vb = VarBuilder::from_pth(filename, DType::F32, device)?;
+
+    Ok((config, vb))
+}
+
+fn assert_tensor(a: &Tensor, b: &Tensor, v: f32) -> Result<()> {
+    assert_eq!(a.dims(), b.dims());
+    let dim = a.dims().len();
+    let mut t = (a - b)?.abs()?;
+
+    for _i in 0..dim {
+        t = t.max(D::Minus1)?;
+    }
+
+    let t = t.to_scalar::<f32>()?;
+    println!("max diff = {}", t);
+    assert!(t < v);
+    Ok(())
+}
+
+fn lstm_config(layer_idx: usize, direction: rnn::Direction) -> LSTMConfig {
+    let mut config = LSTMConfig::default();
+    config.layer_idx = layer_idx;
+    config.direction = direction;
+    config
+}
+
+fn gru_config(layer_idx: usize, direction: rnn::Direction) -> rnn::GRUConfig {
+    let mut config = rnn::GRUConfig::default();
+    config.layer_idx = layer_idx;
+    config.direction = direction;
+    config
+}
+
+fn run_lstm(args: Args) -> Result<Tensor> {
+    let (config, vb, mut input) = args.load_model()?;
+
+    let mut layers = Vec::with_capacity(config.layers);
+
+    for layer_idx in 0..config.layers {
+        let input_dim = if layer_idx == 0 {
+            config.input
+        } else {
+            config.hidden
+        };
+        let lstm_config = lstm_config(layer_idx, rnn::Direction::Forward);
+        let lstm = candle_nn::lstm(input_dim, config.hidden, lstm_config, vb.clone())?;
+        layers.push(lstm);
+    }
+
+    for layer in &layers {
+        let states = layer.seq(&input)?;
+        input = layer.states_to_tensor(&states)?;
+    }
+
+    if args.test {
+        let answer = vb.get(
+            (config.batch_size, config.sequence_length, config.hidden),
+            "output",
+        )?;
+        assert_tensor(&input, &answer, ACCURACY)?;
+    }
+
+    Ok(input)
+}
+
+fn run_gru(args: Args) -> Result<Tensor> {
+    let (config, vb, mut input) = args.load_model()?;
+
+    let mut layers = Vec::with_capacity(config.layers);
+
+    for layer_idx in 0..config.layers {
+        let input_dim = if layer_idx == 0 {
+            config.input
+        } else {
+            config.hidden
+        };
+        let gru_config = gru_config(layer_idx, rnn::Direction::Forward);
+        let gru = candle_nn::gru(input_dim, config.hidden, gru_config, vb.clone())?;
+        layers.push(gru);
+    }
+
+    for layer in &layers {
+        let states = layer.seq(&input)?;
+        input = layer.states_to_tensor(&states)?;
+    }
+
+    if args.test {
+        let answer = vb.get(
+            (config.batch_size, config.sequence_length, config.hidden),
+            "output",
+        )?;
+        assert_tensor(&input, &answer, ACCURACY)?;
+    }
+
+    Ok(input)
+}
+
+fn run_bidirectional_lstm(args: Args) -> Result<Tensor> {
+    let (config, vb, mut input) = args.load_model()?;
+
+    let mut layers = Vec::with_capacity(config.layers);
+
+    for layer_idx in 0..config.layers {
+        let input_dim = if layer_idx == 0 {
+            config.input
+        } else {
+            config.hidden * 2
+        };
+
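+        // Each bidirectional layer pairs a forward and a backward LSTM; `layer_idx`
+        // selects the matching `weight_*_l{idx}` (and `_reverse`) tensors in the checkpoint.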
+        let forward_config = lstm_config(layer_idx, rnn::Direction::Forward);
+        let forward = candle_nn::lstm(input_dim, config.hidden, forward_config, vb.clone())?;
+
+        let backward_config = lstm_config(layer_idx, rnn::Direction::Backward);
+        let backward = candle_nn::lstm(input_dim, config.hidden, backward_config, vb.clone())?;
+
+        layers.push((forward, backward));
+    }
+
+    for (forward, backward) in &layers {
+        let forward_states = forward.seq(&input)?;
+        let backward_states = backward.seq(&input)?;
+        input = forward.bidirectional_states_to_tensor(&forward_states, &backward_states)?;
+    }
+
+    if args.test {
+        let answer = vb.get(
+            (config.batch_size, config.sequence_length, config.hidden * 2),
+            "output",
+        )?;
+        assert_tensor(&input, &answer, ACCURACY)?;
+    }
+
+    Ok(input)
+}
+
+fn run_bidirectional_gru(args: Args) -> Result<Tensor> {
+    let (config, vb, mut input) = args.load_model()?;
+
+    let mut layers = Vec::with_capacity(config.layers);
+    for layer_idx in 0..config.layers {
+        let input_dim = if layer_idx == 0 {
+            config.input
+        } else {
+            config.hidden * 2
+        };
+
+        let forward_config = gru_config(layer_idx, rnn::Direction::Forward);
+        let forward = candle_nn::gru(input_dim, config.hidden, forward_config, vb.clone())?;
+
+        let backward_config = gru_config(layer_idx, rnn::Direction::Backward);
+        let backward = candle_nn::gru(input_dim, config.hidden, backward_config, vb.clone())?;
+
+        layers.push((forward, backward));
+    }
+
+    for (forward, backward) in &layers {
+        let forward_states = forward.seq(&input)?;
+        let backward_states = backward.seq(&input)?;
+        input = forward.bidirectional_states_to_tensor(&forward_states, &backward_states)?;
+    }
+
+    if args.test {
+        let answer = vb.get(
+            (config.batch_size, config.sequence_length, config.hidden * 2),
+            "output",
+        )?;
+        assert_tensor(&input, &answer, ACCURACY)?;
+    }
+
+    Ok(input)
+}
+
+fn main() -> Result<()> {
+    let args = Args::parse();
+
+    println!(
+        "Running {:?} bidirection: {} layers: {} example-test: {}",
+        args.model, args.bidirection, args.layers, args.test
+    );
+
+    if args.test {
+        let test_args = Args {
+            model: WhichModel::LSTM,
+            bidirection: false,
+            layers: 1,
+            ..args
+        };
+        print!("Testing LSTM with 1 layer: ");
+        run_lstm(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::GRU,
+            bidirection: false,
+            layers: 1,
+            ..args
+        };
+        print!("Testing GRU with 1 layer: ");
+        run_gru(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::LSTM,
+            bidirection: true,
+            layers: 1,
+            ..args
+        };
+        print!("Testing bidirectional LSTM with 1 layer: ");
+        run_bidirectional_lstm(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::GRU,
+            bidirection: true,
+            layers: 1,
+            ..args
+        };
+        print!("Testing bidirectional GRU with 1 layer: ");
+        run_bidirectional_gru(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::LSTM,
+            bidirection: false,
+            layers: 3,
+            ..args
+        };
+        print!("Testing LSTM with 3 layers: ");
+        run_lstm(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::GRU,
+            bidirection: false,
+            layers: 3,
+            ..args
+        };
+        print!("Testing GRU with 3 layers: ");
+        run_gru(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::LSTM,
+            bidirection: true,
+            layers: 3,
+            ..args
+        };
+        print!("Testing bidirectional LSTM with 3 layers: ");
+        run_bidirectional_lstm(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::GRU,
+            bidirection: true,
+            layers: 3,
+            ..args
+        };
+        print!("Testing bidirectional GRU with 3 layers: ");
+        run_bidirectional_gru(test_args)?;
+    } else {
+        let num_directions = if args.bidirection { 2 } else { 1 };
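+        // Each run function returns a tensor of shape (batch_size, seq_len, hidden_dim * num_directions).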
+        let batch_size = args.batch_size;
+        let seq_len = args.seq_len;
+        let hidden_dim = args.hidden_dim;
+
+        let output = match (args.model, args.bidirection) {
+            (WhichModel::LSTM, false) => run_lstm(args),
+            (WhichModel::GRU, false) => run_gru(args),
+            (WhichModel::LSTM, true) => run_bidirectional_lstm(args),
+            (WhichModel::GRU, true) => run_bidirectional_gru(args),
+        }?;
+
+        assert_eq!(
+            output.dims3()?,
+            (batch_size, seq_len, hidden_dim * num_directions)
+        );
+        println!("result dims: {:?}", output.dims());
+    }
+
+    Ok(())
+}
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
index 798db6ac4..e3f3e1bac 100644
--- a/candle-nn/src/rnn.rs
+++ b/candle-nn/src/rnn.rs
@@ -6,6 +6,9 @@ use candle::{DType, Device, IndexOp, Result, Tensor};
 
 pub trait RNN {
     type State: Clone;
+    /// Returns the direction of the RNN.
+    fn direction(&self) -> Direction;
+
     /// A zero state from which the recurrent network is usually initialized.
     fn zero_state(&self, batch_dim: usize) -> Result<Self::State>;
 
@@ -31,7 +34,12 @@ pub trait RNN {
         let (_b_size, seq_len, _features) = input.dims3()?;
         let mut output = Vec::with_capacity(seq_len);
         for seq_index in 0..seq_len {
-            let input = input.i((.., seq_index, ..))?.contiguous()?;
+            let index = if self.direction() == Direction::Forward {
+                seq_index
+            } else {
+                seq_len - seq_index - 1
+            };
+            let input = input.i((.., index, ..))?.contiguous()?;
             let state = if seq_index == 0 {
                 self.step(&input, init_state)?
             } else {
@@ -39,11 +47,21 @@
             };
             output.push(state);
         }
+        if self.direction() == Direction::Backward {
+            output.reverse();
+        }
         Ok(output)
     }
 
     /// Converts a sequence of state to a tensor.
     fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor>;
+
+    /// Combines forward and backward states to a tensor.
+    fn bidirectional_states_to_tensor(
+        &self,
+        forward_states: &[Self::State],
+        backward_states: &[Self::State],
+    ) -> Result<Tensor>;
 }
 
 /// The state for a LSTM network, this contains two tensors.
@@ -70,7 +88,7 @@ impl LSTMState {
     }
 }
 
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, PartialEq)]
 pub enum Direction {
     Forward,
     Backward,
@@ -198,6 +216,10 @@ pub fn lstm(
 impl RNN for LSTM {
     type State = LSTMState;
 
+    fn direction(&self) -> Direction {
+        self.config.direction
+    }
+
     fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
         let zeros =
             Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?;
@@ -236,6 +258,22 @@ impl RNN for LSTM {
         let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
         Tensor::stack(&states, 1)
     }
+
+    fn bidirectional_states_to_tensor(
+        &self,
+        forward_states: &[Self::State],
+        backward_states: &[Self::State],
+    ) -> Result<Tensor> {
+        let combine_states = forward_states
+            .iter()
+            .zip(backward_states.iter())
+            .collect::<Vec<_>>();
+        let mut states = Vec::with_capacity(combine_states.len());
+        for (f, b) in combine_states {
+            states.push(Tensor::cat(&[&f.h, &b.h], 1)?);
+        }
+        Tensor::stack(&states, 1)
+    }
 }
 
 /// The state for a GRU network, this contains a single tensor.
@@ -259,6 +297,8 @@ pub struct GRUConfig {
     pub w_hh_init: super::Init,
     pub b_ih_init: Option<super::Init>,
     pub b_hh_init: Option<super::Init>,
+    pub layer_idx: usize,
+    pub direction: Direction,
 }
 
 impl Default for GRUConfig {
@@ -268,6 +308,8 @@ impl Default for GRUConfig {
             w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
             b_ih_init: Some(super::Init::Const(0.)),
             b_hh_init: Some(super::Init::Const(0.)),
+            layer_idx: 0,
+            direction: Direction::Forward,
         }
     }
 }
@@ -279,6 +321,8 @@ impl GRUConfig {
             w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
             b_ih_init: None,
             b_hh_init: None,
+            layer_idx: 0,
+            direction: Direction::Forward,
         }
     }
 }
@@ -307,22 +351,35 @@ impl GRU {
         config: GRUConfig,
         vb: crate::VarBuilder,
     ) -> Result<Self> {
+        let layer_idx = config.layer_idx;
+        let direction_str = match config.direction {
+            Direction::Forward => "",
+            Direction::Backward => "_reverse",
+        };
         let w_ih = vb.get_with_hints(
             (3 * hidden_dim, in_dim),
-            "weight_ih_l0", // Only a single layer is supported.
+            &format!("weight_ih_l{layer_idx}{direction_str}"),
             config.w_ih_init,
         )?;
         let w_hh = vb.get_with_hints(
             (3 * hidden_dim, hidden_dim),
-            "weight_hh_l0", // Only a single layer is supported.
+            &format!("weight_hh_l{layer_idx}{direction_str}"),
            config.w_hh_init,
         )?;
         let b_ih = match config.b_ih_init {
-            Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
+            Some(init) => Some(vb.get_with_hints(
+                3 * hidden_dim,
+                &format!("bias_ih_l{layer_idx}{direction_str}"),
+                init,
+            )?),
             None => None,
         };
         let b_hh = match config.b_hh_init {
-            Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
+            Some(init) => Some(vb.get_with_hints(
+                3 * hidden_dim,
+                &format!("bias_hh_l{layer_idx}{direction_str}"),
+                init,
+            )?),
             None => None,
         };
         Ok(Self {
@@ -354,6 +411,10 @@ pub fn gru(
 impl RNN for GRU {
     type State = GRUState;
 
+    fn direction(&self) -> Direction {
+        self.config.direction
+    }
+
     fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
         let h =
             Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?;
@@ -383,6 +444,22 @@ impl RNN for GRU {
 
     fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
         let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
-        Tensor::cat(&states, 1)
+        Tensor::stack(&states, 1)
+    }
+
+    fn bidirectional_states_to_tensor(
+        &self,
+        forward_states: &[Self::State],
+        backward_states: &[Self::State],
+    ) -> Result<Tensor> {
+        let combine_states = forward_states
+            .iter()
+            .zip(backward_states.iter())
+            .collect::<Vec<_>>();
+        let mut states = Vec::with_capacity(combine_states.len());
+        for (f, b) in combine_states {
+            states.push(Tensor::cat(&[&f.h, &b.h], 1)?);
+        }
+        Tensor::stack(&states, 1)
     }
 }
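
For reference, the sketch below (not part of the diff) shows how the bidirectional API added to `candle-nn/src/rnn.rs` can be driven on its own: it builds one forward and one backward GRU for a single layer with randomly initialized weights, runs both over the same input, and concatenates the two directions with `bidirectional_states_to_tensor`. The shapes, device, and hyper-parameters are illustrative only; a real model would load its weights from a checkpoint as the example binary does.

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{rnn, VarBuilder, VarMap, RNN};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (batch, seq_len, in_dim, hidden) = (2, 4, 10, 20);

    // Random weights for the sketch; a real model would load a checkpoint instead.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // One forward and one backward GRU make up a single bidirectional layer (layer 0).
    let mut fwd_cfg = rnn::GRUConfig::default();
    fwd_cfg.direction = rnn::Direction::Forward;
    let mut bwd_cfg = rnn::GRUConfig::default();
    bwd_cfg.direction = rnn::Direction::Backward;
    let fwd = candle_nn::gru(in_dim, hidden, fwd_cfg, vb.clone())?;
    let bwd = candle_nn::gru(in_dim, hidden, bwd_cfg, vb.clone())?;

    // Run the same input through both directions and concatenate the per-step states.
    let input = Tensor::randn(0.0_f32, 1.0, (batch, seq_len, in_dim), &device)?;
    let fwd_states = fwd.seq(&input)?;
    let bwd_states = bwd.seq(&input)?;
    let output = fwd.bidirectional_states_to_tensor(&fwd_states, &bwd_states)?;

    // The last dimension holds both directions: hidden * 2.
    assert_eq!(output.dims3()?, (batch, seq_len, 2 * hidden));
    Ok(())
}
```

The example binary repeats this pattern once per layer, feeding the concatenated output of one layer as the input of the next.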