diff --git a/burn-import/src/burn/node/base.rs b/burn-import/src/burn/node/base.rs
index 9a082559ef..65a96776fe 100644
--- a/burn-import/src/burn/node/base.rs
+++ b/burn-import/src/burn/node/base.rs
@@ -1,8 +1,7 @@
 use super::{
     add::AddNode, batch_norm::BatchNormNode, constant::ConstantNode, conv2d::Conv2dNode,
-    equal::EqualNode, flatten::FlattenNode, linear::LinearNode, log_softmax::LogSoftmaxNode,
-    matmul::MatmulNode, max_pool2d::MaxPool2dNode, relu::ReLUNode, reshape::ReshapeNode,
-    sigmoid::SigmoidNode,
+    equal::EqualNode, linear::LinearNode, matmul::MatmulNode, max_pool2d::MaxPool2dNode,
+    reshape::ReshapeNode, unary::UnaryNode,
 };
 use crate::burn::{BurnImports, Scope, Type};
 use burn::record::PrecisionSettings;
@@ -78,13 +77,10 @@ pub enum Node<PS: PrecisionSettings> {
     MaxPool2d(MaxPool2dNode),
     Linear(LinearNode<PS>),
     BatchNorm(BatchNormNode<PS>),
-    ReLU(ReLUNode),
-    Flatten(FlattenNode),
-    LogSoftmax(LogSoftmaxNode),
     Constant(ConstantNode),
     Equal(EqualNode),
+    Unary(UnaryNode),
     Reshape(ReshapeNode),
-    Sigmoid(SigmoidNode),
 }
 
 macro_rules! match_all {
@@ -96,13 +92,10 @@ macro_rules! match_all {
             Node::MaxPool2d(node) => $func(node),
             Node::Linear(node) => $func(node),
             Node::BatchNorm(node) => $func(node),
-            Node::ReLU(node) => $func(node),
-            Node::Flatten(node) => $func(node),
-            Node::LogSoftmax(node) => $func(node),
             Node::Constant(node) => $func(node),
             Node::Equal(node) => $func(node),
             Node::Reshape(node) => $func(node),
-            Node::Sigmoid(node) => $func(node),
+            Node::Unary(node) => $func(node),
         }
     }};
 }
@@ -126,12 +119,9 @@ impl<PS: PrecisionSettings> Node<PS> {
             Node::MaxPool2d(_) => "max_pool2d",
             Node::Linear(_) => "linear",
             Node::BatchNorm(_) => "batch_norm",
-            Node::ReLU(_) => "relu",
-            Node::Flatten(_) => "flatten",
-            Node::LogSoftmax(_) => "log_softmax",
             Node::Equal(_) => "equal",
             Node::Reshape(_) => "reshape",
-            Node::Sigmoid(_) => "sigmoid",
+            Node::Unary(unary) => unary.kind.as_str(),
         }
     }
 }
@@ -264,9 +254,9 @@ pub(crate) mod tests {
         module::Module,
         tensor::{backend::Backend, Tensor},
     };
-    use burn::nn::PaddingConfig2d;
-    use burn::nn::conv::Conv2d;
     use burn::nn::conv::Conv2dConfig;
+    use burn::nn::conv::Conv2d;
+    use burn::nn::PaddingConfig2d;
 
     #[derive(Module, Debug)]
     pub struct Model<B: Backend> {
diff --git a/burn-import/src/burn/node/flatten.rs b/burn-import/src/burn/node/flatten.rs
deleted file mode 100644
index d45160c9d9..0000000000
--- a/burn-import/src/burn/node/flatten.rs
+++ /dev/null
@@ -1,64 +0,0 @@
-use super::{Node, NodeCodegen};
-use crate::burn::{Scope, TensorType, ToTokens, Type};
-use burn::record::PrecisionSettings;
-use proc_macro2::TokenStream;
-use quote::quote;
-
-#[derive(Debug, Clone, new)]
-pub struct FlattenNode {
-    pub input: TensorType,
-    pub output: TensorType,
-    pub start_dim: usize,
-    pub end_dim: usize,
-}
-
-impl<PS: PrecisionSettings> NodeCodegen<PS> for FlattenNode {
-    fn output_types(&self) -> Vec<Type> {
-        vec![Type::Tensor(&self.output)]
-    }
-
-    fn input_types(&self) -> Vec<Type> {
-        vec![Type::Tensor(&self.input)]
-    }
-
-    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
-        let input = scope.tensor_use_owned(&self.input, node_position);
-        let output = &self.output.name;
-        let start_dim = self.start_dim.to_tokens();
-        let end_dim = self.end_dim.to_tokens();
-
-        quote! {
-            let #output = #input.flatten(#start_dim, #end_dim);
-        }
-    }
-
-    fn into_node(self) -> Node<PS> {
-        Node::Flatten(self)
-    }
-}
-
-#[cfg(test)]
-mod tests {
-
-    use super::*;
-    use crate::burn::{node::flatten::FlattenNode, TensorType};
-
-    use crate::burn::node::tests::codegen_unary_operator;
-
-    #[test]
-    fn test_codegen_node() {
-        codegen_unary_operator::<4, _>(
-            FlattenNode::new(
-                TensorType::new_float("tensor1", 4),
-                TensorType::new_float("tensor2", 4),
-                1,
-                2,
-            ),
-            quote! {
-                let tensor2 = tensor1.flatten(1, 2);
-
-                tensor2
-            },
-        );
-    }
-}
diff --git a/burn-import/src/burn/node/log_softmax.rs b/burn-import/src/burn/node/log_softmax.rs
deleted file mode 100644
index e3e048c975..0000000000
--- a/burn-import/src/burn/node/log_softmax.rs
+++ /dev/null
@@ -1,61 +0,0 @@
-use super::{Node, NodeCodegen};
-use crate::burn::{Scope, TensorType, ToTokens, Type};
-use burn::record::PrecisionSettings;
-use proc_macro2::TokenStream;
-use quote::quote;
-
-#[derive(Debug, Clone, new)]
-pub struct LogSoftmaxNode {
-    pub input: TensorType,
-    pub output: TensorType,
-    pub dim: usize,
-}
-
-impl<PS: PrecisionSettings> NodeCodegen<PS> for LogSoftmaxNode {
-    fn output_types(&self) -> Vec<Type> {
-        vec![Type::Tensor(&self.output)]
-    }
-
-    fn input_types(&self) -> Vec<Type> {
-        vec![Type::Tensor(&self.input)]
-    }
-
-    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
-        let input = scope.tensor_use_owned(&self.input, node_position);
-        let output = &self.output.name;
-        let dim = self.dim.to_tokens();
-
-        quote! {
-            let #output = burn::tensor::activation::log_softmax(#input, #dim);
-        }
-    }
-
-    fn into_node(self) -> Node<PS> {
-        Node::LogSoftmax(self)
-    }
-}
-
-#[cfg(test)]
-mod tests {
-
-    use super::*;
-    use crate::burn::{node::log_softmax::LogSoftmaxNode, TensorType};
-
-    use crate::burn::node::tests::codegen_unary_operator;
-
-    #[test]
-    fn test_codegen_node() {
-        codegen_unary_operator::<4, _>(
-            LogSoftmaxNode::new(
-                TensorType::new_float("tensor1", 4),
-                TensorType::new_float("tensor2", 4),
-                1,
-            ),
-            quote! {
-                let tensor2 = burn::tensor::activation::log_softmax(tensor1, 1);
-
-                tensor2
-            },
-        );
-    }
-}
diff --git a/burn-import/src/burn/node/mod.rs b/burn-import/src/burn/node/mod.rs
index f9457e1a56..c4e0485634 100644
--- a/burn-import/src/burn/node/mod.rs
+++ b/burn-import/src/burn/node/mod.rs
@@ -5,14 +5,11 @@ pub(crate) mod batch_norm;
 pub(crate) mod constant;
 pub(crate) mod conv2d;
 pub(crate) mod equal;
-pub(crate) mod flatten;
 pub(crate) mod linear;
-pub(crate) mod log_softmax;
 pub(crate) mod matmul;
 pub(crate) mod max_pool2d;
-pub(crate) mod relu;
 pub(crate) mod reshape;
-pub(crate) mod sigmoid;
+pub(crate) mod unary;
 
 pub(crate) use base::*;
 
diff --git a/burn-import/src/burn/node/relu.rs b/burn-import/src/burn/node/relu.rs
deleted file mode 100644
index 1d5b7a608c..0000000000
--- a/burn-import/src/burn/node/relu.rs
+++ /dev/null
@@ -1,58 +0,0 @@
-use super::{Node, NodeCodegen};
-use crate::burn::{Scope, TensorType, Type};
-use burn::record::PrecisionSettings;
-use proc_macro2::TokenStream;
-use quote::quote;
-
-#[derive(Debug, Clone, new)]
-pub struct ReLUNode {
-    pub input: TensorType,
-    pub output: TensorType,
-}
-
-impl<PS: PrecisionSettings> NodeCodegen<PS> for ReLUNode {
-    fn output_types(&self) -> Vec<Type> {
-        vec![Type::Tensor(&self.output)]
-    }
-
-    fn input_types(&self) -> Vec<Type> {
-        vec![Type::Tensor(&self.input)]
-    }
-
-    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
-        let input = scope.tensor_use_owned(&self.input, node_position);
-        let output = &self.output.name;
-
-        quote! {
-            let #output = burn::tensor::activation::relu(#input);
-        }
-    }
-
-    fn into_node(self) -> Node<PS> {
-        Node::ReLU(self)
-    }
-}
-
-#[cfg(test)]
-mod tests {
-
-    use super::*;
-    use crate::burn::{node::relu::ReLUNode, TensorType};
-
-    use crate::burn::node::tests::codegen_unary_operator;
-
-    #[test]
-    fn test_codegen_node() {
-        codegen_unary_operator::<4, _>(
-            ReLUNode::new(
-                TensorType::new_float("tensor1", 4),
-                TensorType::new_float("tensor2", 4),
-            ),
-            quote! {
-                let tensor2 = burn::tensor::activation::relu(tensor1);
-
-                tensor2
-            },
-        );
-    }
-}
diff --git a/burn-import/src/burn/node/sigmoid.rs b/burn-import/src/burn/node/sigmoid.rs
deleted file mode 100644
index a024783ea8..0000000000
--- a/burn-import/src/burn/node/sigmoid.rs
+++ /dev/null
@@ -1,61 +0,0 @@
-use proc_macro2::TokenStream;
-use quote::quote;
-
-use burn::record::PrecisionSettings;
-
-use super::{Node, NodeCodegen};
-
-use crate::burn::{Scope, TensorType, Type};
-
-#[derive(Debug, Clone, new)]
-pub struct SigmoidNode {
-    pub input: TensorType,
-    pub output: TensorType,
-}
-
-impl<PS: PrecisionSettings> NodeCodegen<PS> for SigmoidNode {
-    fn output_types(&self) -> Vec<Type> {
-        vec![Type::Tensor(&self.output)]
-    }
-
-    fn input_types(&self) -> Vec<Type> {
-        vec![Type::Tensor(&self.input)]
-    }
-
-    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
-        let input = scope.tensor_use_owned(&self.input, node_position);
-        let output = &self.output.name;
-
-        quote! {
-            let #output = #input.sigmoid();
-        }
-    }
-
-    fn into_node(self) -> Node<PS> {
-        Node::Sigmoid(self)
-    }
-}
-
-#[cfg(test)]
-mod tests {
-
-    use super::*;
-    use crate::burn::{node::sigmoid::SigmoidNode, TensorType};
-
-    use crate::burn::node::tests::codegen_unary_operator;
-
-    #[test]
-    fn test_codegen_node() {
-        codegen_unary_operator::<4, _>(
-            SigmoidNode::new(
-                TensorType::new_float("tensor1", 4),
-                TensorType::new_float("tensor2", 4),
-            ),
-            quote! {
-                let tensor2 = tensor1.sigmoid();
-
-                tensor2
-            },
-        );
-    }
-}
diff --git a/burn-import/src/burn/node/unary.rs b/burn-import/src/burn/node/unary.rs
new file mode 100644
index 0000000000..34e415dc8d
--- /dev/null
+++ b/burn-import/src/burn/node/unary.rs
@@ -0,0 +1,177 @@
+use super::{Node, NodeCodegen};
+use crate::burn::{Scope, TensorType, ToTokens, Type};
+use burn::record::PrecisionSettings;
+use proc_macro2::TokenStream;
+use quote::quote;
+use std::sync::Arc;
+
+// Simple fn pointer that receives the input as a token stream and returns the function call.
+type FnPointer = Arc<dyn Fn(TokenStream) -> TokenStream>;
+
+/// Node for all unary operators.
+#[derive(Clone, new)]
+pub struct UnaryNode {
+    pub input: TensorType,
+    pub output: TensorType,
+    pub kind: UnaryNodeKind,
+    function: FnPointer,
+}
+
+/// Type of unary node.
+#[derive(Clone)]
+pub enum UnaryNodeKind {
+    Flatten,
+    Relu,
+    Sigmoid,
+    LogSoftmax,
+}
+
+impl UnaryNodeKind {
+    pub fn as_str(&self) -> &str {
+        match self {
+            Self::Flatten => "flatten",
+            Self::Relu => "relu",
+            Self::Sigmoid => "sigmoid",
+            Self::LogSoftmax => "log_softmax",
+        }
+    }
+}
+
+impl std::fmt::Debug for UnaryNode {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(
+            format!(
+                "UnaryNode {{ input: {:?}, output: {:?}, name: {} }}",
+                self.input,
+                self.output,
+                self.kind.as_str()
+            )
+            .as_str(),
+        )
+    }
+}
+
+impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
+    fn output_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(&self.output)]
+    }
+
+    fn input_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(&self.input)]
+    }
+
+    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
+        let input = scope.tensor_use_owned(&self.input, node_position);
+        let output = &self.output.name;
+        let function = (self.function)(input);
+
+        quote! {
+            let #output = #function;
+        }
+    }
+
+    fn into_node(self) -> Node<PS> {
+        Node::Unary(self)
+    }
+}
+
+impl UnaryNode {
+    pub(crate) fn flatten(
+        input: TensorType,
+        output: TensorType,
+        start_dim: usize,
+        end_dim: usize,
+    ) -> Self {
+        let start_dim = start_dim.to_tokens();
+        let end_dim = end_dim.to_tokens();
+        let function = move |input| quote! { #input.flatten(#start_dim, #end_dim) };
+
+        Self::new(input, output, UnaryNodeKind::Flatten, Arc::new(function))
+    }
+
+    pub(crate) fn relu(input: TensorType, output: TensorType) -> Self {
+        let function = move |input| quote! { burn::tensor::activation::relu(#input) };
+        Self::new(input, output, UnaryNodeKind::Relu, Arc::new(function))
+    }
+
+    pub(crate) fn sigmoid(input: TensorType, output: TensorType) -> Self {
+        let function = move |input| quote! { burn::tensor::activation::sigmoid(#input) };
+        Self::new(input, output, UnaryNodeKind::Sigmoid, Arc::new(function))
+    }
+
+    pub(crate) fn log_softmax(input: TensorType, output: TensorType, dim: usize) -> Self {
+        let dim = dim.to_tokens();
+        let function = move |input| quote! { burn::tensor::activation::log_softmax(#input, #dim) };
+        Self::new(input, output, UnaryNodeKind::LogSoftmax, Arc::new(function))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::burn::node::tests::codegen_unary_operator;
+    use crate::burn::TensorType;
+
+    #[test]
+    fn test_unary_codegen_flatten() {
+        codegen_unary_operator::<4, _>(
+            UnaryNode::flatten(
+                TensorType::new_float("tensor1", 4),
+                TensorType::new_float("tensor2", 4),
+                1,
+                2,
+            ),
+            quote! {
+                let tensor2 = tensor1.flatten(1, 2);
+
+                tensor2
+            },
+        );
+    }
+
+    #[test]
+    fn test_unary_codegen_relu() {
+        codegen_unary_operator::<4, _>(
+            UnaryNode::relu(
+                TensorType::new_float("tensor1", 4),
+                TensorType::new_float("tensor2", 4),
+            ),
+            quote! {
+                let tensor2 = burn::tensor::activation::relu(tensor1);
+
+                tensor2
+            },
+        );
+    }
+
+    #[test]
+    fn test_unary_codegen_sigmoid() {
+        codegen_unary_operator::<4, _>(
+            UnaryNode::sigmoid(
+                TensorType::new_float("tensor1", 4),
+                TensorType::new_float("tensor2", 4),
+            ),
+            quote! {
+                let tensor2 = burn::tensor::activation::sigmoid(tensor1);
+
+                tensor2
+            },
+        );
+    }
+
+    #[test]
+    fn test_unary_codegen_log_softmax() {
+        codegen_unary_operator::<4, _>(
+            UnaryNode::log_softmax(
+                TensorType::new_float("tensor1", 4),
+                TensorType::new_float("tensor2", 4),
+                1,
+            ),
+            quote! {
+                let tensor2 = burn::tensor::activation::log_softmax(tensor1, 1);
+
+                tensor2
+            },
+        );
+    }
+}
diff --git a/burn-import/src/onnx/to_burn.rs b/burn-import/src/onnx/to_burn.rs
index f711ceb8f9..326ff9a9b2 100644
--- a/burn-import/src/onnx/to_burn.rs
+++ b/burn-import/src/onnx/to_burn.rs
@@ -18,14 +18,11 @@ use crate::{
             constant::{ConstantNode, ConstantValue},
             conv2d::Conv2dNode,
             equal::EqualNode,
-            flatten::FlattenNode,
             linear::LinearNode,
-            log_softmax::LogSoftmaxNode,
             matmul::MatmulNode,
             max_pool2d::MaxPool2dNode,
-            relu::ReLUNode,
             reshape::ReshapeNode,
-            sigmoid::SigmoidNode,
+            unary::UnaryNode,
         },
         TensorType,
     },
@@ -216,11 +213,19 @@ impl ONNXGraph {
         EqualNode::new(lhs, rhs, output)
     }
 
-    fn relu_conversion(node: Node) -> ReLUNode {
+    fn relu_conversion(node: Node) -> UnaryNode {
         let input = node.inputs.get(0).unwrap().to_tensor_type();
         let output = node.outputs.get(0).unwrap().to_tensor_type();
 
-        ReLUNode::new(input, output)
+        UnaryNode::relu(input, output)
+    }
+
+    fn flatten_conversion(node: Node) -> UnaryNode {
+        let input = node.inputs.get(0).unwrap().to_tensor_type();
+        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let (start_dim, end_dim) = flatten_config(&node);
+
+        UnaryNode::flatten(input, output, start_dim, end_dim)
     }
 
     fn reshape_conversion(mut node: Node) -> ReshapeNode {
@@ -235,27 +240,19 @@
         )
     }
 
-    fn flatten_conversion(node: Node) -> FlattenNode {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
-        let (start_dim, end_dim) = flatten_config(&node);
-
-        FlattenNode::new(input, output, start_dim, end_dim)
-    }
-
-    fn sigmoid_conversion(node: Node) -> SigmoidNode {
+    fn sigmoid_conversion(node: Node) -> UnaryNode {
         let input = node.inputs.get(0).unwrap().to_tensor_type();
         let output = node.outputs.get(0).unwrap().to_tensor_type();
 
-        SigmoidNode::new(input, output)
+        UnaryNode::sigmoid(input, output)
     }
 
-    fn log_softmax_conversion(node: Node) -> LogSoftmaxNode {
+    fn log_softmax_conversion(node: Node) -> UnaryNode {
         let input = node.inputs.get(0).unwrap().to_tensor_type();
         let output = node.outputs.get(0).unwrap().to_tensor_type();
         let dim = log_softmax_config(&node);
 
-        LogSoftmaxNode::new(input, output, dim)
+        UnaryNode::log_softmax(input, output, dim)
     }
 
     fn linear_conversion<PS: PrecisionSettings>(mut node: Node) -> LinearNode<PS> {
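
Beyond the mechanical consolidation, the interesting piece of this refactor is the FnPointer alias: each UnaryNode constructor captures its operator-specific configuration in a move closure that maps the input expression (a TokenStream) to the generated call, so a single forward implementation serves every unary operator. The sketch below isolates that pattern outside burn-import; the tanh operator is hypothetical and only illustrative (it is not added by this diff), and the snippet assumes the proc-macro2 and quote crates as dependencies.

// Sketch only, not part of the diff: standalone demonstration of the
// FnPointer pattern introduced in unary.rs. `tanh` is a hypothetical next
// operator, assuming burn exposes `burn::tensor::activation::tanh`.
use proc_macro2::TokenStream;
use quote::quote;
use std::sync::Arc;

// Same alias as in unary.rs: a shared closure from the input expression
// to the expression that applies the operator.
type FnPointer = Arc<dyn Fn(TokenStream) -> TokenStream>;

fn tanh_function() -> FnPointer {
    // Stateless operators capture nothing. Configured operators (e.g.
    // flatten) tokenize their arguments first and capture them by `move`.
    Arc::new(|input| quote! { burn::tensor::activation::tanh(#input) })
}

fn main() {
    // Simulate what UnaryNode::forward does with the stored closure.
    let input: TokenStream = quote! { tensor1 };
    let function = (tanh_function())(input);
    let output = quote! { tensor2 };
    let forward = quote! { let #output = #function; };
    // Prints roughly: let tensor2 = burn :: tensor :: activation :: tanh (tensor1) ;
    println!("{forward}");
}

Using Arc<dyn Fn(...)> rather than Box is what lets UnaryNode keep its #[derive(Clone)]: cloning a node only bumps the reference count on the shared closure. It also explains the hand-written Debug impl in the diff, since closures themselves are not Debug.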