From 94db0420399ffc83212d0a43580f03eede74f010 Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Wed, 10 Apr 2024 13:51:46 -0400
Subject: [PATCH 1/2] Add reduce mean onnx op support

---
 crates/burn-import/SUPPORTED-ONNX-OPS.md      |  2 +-
 crates/burn-import/onnx-tests/build.rs        |  1 +
 .../onnx-tests/tests/onnx_tests.rs            | 17 +++++
 .../tests/reduce_mean/reduce_mean.onnx        | Bin 0 -> 394 bytes
 .../tests/reduce_mean/reduce_mean.py          | 46 +++++++++++++
 crates/burn-import/src/burn/node/unary.rs     | 65 ++++++++++++++++++
 crates/burn-import/src/onnx/dim_inference.rs  |  8 ++-
 .../burn-import/src/onnx/op_configuration.rs  | 44 ++++++++++++
 crates/burn-import/src/onnx/to_burn.rs        | 10 +++
 9 files changed, 189 insertions(+), 4 deletions(-)
 create mode 100644 crates/burn-import/onnx-tests/tests/reduce_mean/reduce_mean.onnx
 create mode 100755 crates/burn-import/onnx-tests/tests/reduce_mean/reduce_mean.py

diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md
index 0185517dd4..30446e7e50 100644
--- a/crates/burn-import/SUPPORTED-ONNX-OPS.md
+++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -140,7 +140,7 @@ represent the corresponding Burn Op.
 | [ReduceLogSum][133]    | ❌ | ❌ |
 | [ReduceLogSumExp][134] | ❌ | ❌ |
 | [ReduceMax][135]       | ❌ | ✅ |
-| [ReduceMean][136]      | ❌ | ✅ |
+| [ReduceMean][136]      | ✅ | ✅ |
 | [ReduceMin][137]       | ❌ | ✅ |
 | [ReduceProd][138]      | ❌ | ❌ |
 | [ReduceSum][139]       | ❌ | ✅ |
diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs
index f6ddadcc9f..19eaf6a342 100644
--- a/crates/burn-import/onnx-tests/build.rs
+++ b/crates/burn-import/onnx-tests/build.rs
@@ -35,6 +35,7 @@ fn main() {
         .input("tests/recip/recip.onnx")
         .input("tests/relu/relu.onnx")
         .input("tests/leaky_relu/leaky_relu.onnx")
+        .input("tests/reduce_mean/reduce_mean.onnx")
         .input("tests/reshape/reshape.onnx")
         .input("tests/sigmoid/sigmoid.onnx")
         .input("tests/softmax/softmax.onnx")
diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
index c0faf32882..86eab82aaa 100644
--- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs
+++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -41,6 +41,7 @@ include_models!(
     mul,
     neg,
     recip,
+    reduce_mean,
     relu,
     reshape,
     sigmoid,
@@ -442,6 +443,22 @@ mod tests {
         output3.to_data().assert_approx_eq(&expected3, 3);
     }
 
+    #[test]
+    fn reduce_mean() {
+        let device = Default::default();
+        let model: reduce_mean::Model<Backend> = reduce_mean::Model::new(&device);
+
+        // Run the model
+        let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);
+        let (output_scalar, output_tensor, output_value) = model.forward(input.clone());
+        let expected_scalar = Data::from([9.75]);
+        let expected = Data::from([[[[9.75]]]]);
+
+        assert_eq!(output_scalar.to_data(), expected_scalar);
+        assert_eq!(output_tensor.to_data(), input.to_data());
+        assert_eq!(output_value.to_data(), expected);
+    }
+
     #[test]
     fn reshape() {
         // Initialize the model without weights (because the exported file does not contain them)
diff --git a/crates/burn-import/onnx-tests/tests/reduce_mean/reduce_mean.onnx b/crates/burn-import/onnx-tests/tests/reduce_mean/reduce_mean.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..e08c8ec4c5e8e3c34d6cbbb333464e565ff8cf5e
GIT binary patch
literal 394
zcmd~ID8;Rh
z5K-bnu(kNPII>ey3sN$3izOHqFfwrkU^l`@iWkX
[... remainder of the binary literal, the diff for tests/reduce_mean/reduce_mean.py, and the start of the unary.rs diff truncated ...]

diff --git a/crates/burn-import/src/burn/node/unary.rs b/crates/burn-import/src/burn/node/unary.rs
@@ ... @@ impl UnaryNodeKind {
             Self::Log => "log",
             Self::LogSoftmax => "log_softmax",
             Self::Neg => "neg",
+            Self::ReduceMean => "reduce_mean",
             Self::Reciprocal => "reciprocal",
             Self::LeakyRelu => "leaky_relu",
             Self::Relu => "relu",
@@ -246,6 +248,32 @@ impl UnaryNode {
             _ => panic!("output must be a tensor"),
         }
     }
+
+    pub(crate) fn reduce_mean(input: Type, output: Type, dim: Option<usize>) -> Self {
+        // ReduceMean is constrained to numeric tensors, so no need to check for bool.
+        if let Type::Tensor(_) = output {
+            if let Some(dim) = dim {
+                // ReduceMean, keepdims=1, axes=[dim]
+                let dim = dim.to_tokens();
+                Self::new(
+                    input,
+                    output,
+                    UnaryNodeKind::ReduceMean,
+                    Rc::new(move |input| quote! { #input.mean_dim(#dim) }),
+                )
+            } else {
+                // ReduceMean, keepdims=0, axes=None
+                Self::new(
+                    input,
+                    output,
+                    UnaryNodeKind::ReduceMean,
+                    Rc::new(move |input| quote! { #input.mean() }),
+                )
+            }
+        } else {
+            panic!("ReduceMean only supports tensor output");
+        }
+    }
 }
 
 #[cfg(test)]
@@ -430,6 +458,43 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_unary_codegen_reduce_mean() {
+        one_node_graph(
+            UnaryNode::reduce_mean(
+                Type::Tensor(TensorType::new_float("tensor1", 4)),
+                Type::Tensor(TensorType::new_float("tensor2", 4)),
+                Some(1),
+            ),
+            quote! {
+                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
+                    let tensor2 = tensor1.mean_dim(1);
+
+                    tensor2
+                }
+            },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+
+        one_node_graph(
+            UnaryNode::reduce_mean(
+                Type::Tensor(TensorType::new_float("tensor1", 4)),
+                Type::Tensor(TensorType::new_float("tensor2", 1)),
+                None,
+            ),
+            quote! {
+                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1> {
+                    let tensor2 = tensor1.mean();
+
+                    tensor2
+                }
+            },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+    }
+
     #[test]
     fn test_unary_codegen_reciprocal() {
         one_node_graph(
diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs
index 0abb739600..2a4fa9cfe7 100644
--- a/crates/burn-import/src/onnx/dim_inference.rs
+++ b/crates/burn-import/src/onnx/dim_inference.rs
@@ -39,7 +39,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
         NodeType::Mul => same_as_input(node),
         NodeType::Neg => same_as_input(node),
         NodeType::Reciprocal => same_as_input(node),
-        NodeType::ReduceMean => mean_update_outputs(node),
+        NodeType::ReduceMean => reduce_mean_update_outputs(node),
         NodeType::Relu => same_as_input(node),
         NodeType::Reshape => reshape_update_outputs(node),
         NodeType::Shape => shape_update_outputs(node),
@@ -205,12 +205,11 @@ fn reshape_update_outputs(node: &mut Node) {
     });
 }
 
-fn mean_update_outputs(node: &mut Node) {
+fn reduce_mean_update_outputs(node: &mut Node) {
     if node.inputs.len() != 1 {
         panic!("Mean: multiple inputs are not supported");
     }
 
-    // Extract the configuration of the linear layer (inputs are known)
     let node_input = &mut node.inputs[0];
     let tensor = match node_input.clone().ty {
         ArgType::Tensor(tensor) => tensor,
@@ -229,6 +228,9 @@ fn mean_update_outputs(node: &mut Node) {
     if dim_only {
         node.outputs[0].ty = ArgType::Tensor(tensor);
     } else {
+        // NOTE: ReduceMax w/o keepdims reduces to a scalar value, but Burn doesn't have
+        // 0-dim tensor so we can't track or perform other ops on that value
+        // node.outputs[0].ty = ArgType::Scalar(tensor.elem_type);
         node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor });
     }
 }
diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs
index 419be15c49..0a60162401 100644
--- a/crates/burn-import/src/onnx/op_configuration.rs
+++ b/crates/burn-import/src/onnx/op_configuration.rs
@@ -660,3 +660,47 @@ fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d {
         panic!("Padding configuration ({:?}) not supported", pads);
     }
 }
+
+pub fn reduce_mean_config(node: &Node) -> Option<usize> {
+    let mut axes = Vec::new();
+    let mut keepdims = 1;
+
+    let tensor = match node.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Extract the attributes
+    for (key, value) in node.attrs.iter() {
+        match key.as_str() {
+            "axes" => axes = value.clone().into_i64s(),
+            "keepdims" => keepdims = value.clone().into_i64(),
+            _ => {}
+        }
+    }
+
+    if axes.len() > 1 {
+        panic!("ReduceMean: reducing on multiple dimensions is not supported")
+    }
+
+    if axes.is_empty() && keepdims == 1 {
+        panic!("ReduceMean: axes must be provided with keepdims")
+    }
+
+    if !axes.is_empty() && keepdims == 0 {
+        // Not supported in Burn
+        panic!("ReduceMean: the reduce operation must preserve the reduced dimension")
+    }
+
+    if axes.is_empty() {
+        None
+    } else {
+        let mut dim = axes[0];
+
+        if dim < 0 {
+            // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim
+            dim += tensor.dim as i64;
+        }
+        Some(dim as usize)
+    }
+}
diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs
index 7ca2d6c94a..665dafa245 100644
--- a/crates/burn-import/src/onnx/to_burn.rs
+++ b/crates/burn-import/src/onnx/to_burn.rs
@@ -252,6 +252,7 @@ impl OnnxGraph {
             NodeType::Sqrt => graph.register(Self::sqrt_conversion(node)),
             NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
             NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
+            NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
             NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
             NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
             NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
@@ -462,6 +463,15 @@ impl OnnxGraph {
         ReshapeNode::new(input, output, shape)
     }
+
+    fn reduce_mean_conversion(node: Node) -> UnaryNode {
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
+        let dim = reduce_mean_config(&node);
+
+        UnaryNode::reduce_mean(input, output, dim)
+    }
 
     fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
         let input = node.inputs.first().unwrap().to_tensor_type();
         let output = node.outputs.first().unwrap().to_tensor_type();

From 6805081c657a39612699e8b456a8d81dbda538ad Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Mon, 15 Apr 2024 11:59:59 -0400
Subject: [PATCH 2/2] Fix comment

---
 crates/burn-import/src/onnx/dim_inference.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs
index 2a4fa9cfe7..3195c09870 100644
--- a/crates/burn-import/src/onnx/dim_inference.rs
+++ b/crates/burn-import/src/onnx/dim_inference.rs
@@ -228,7 +228,7 @@ fn reduce_mean_update_outputs(node: &mut Node) {
     if dim_only {
         node.outputs[0].ty = ArgType::Tensor(tensor);
     } else {
-        // NOTE: ReduceMax w/o keepdims reduces to a scalar value, but Burn doesn't have
+        // NOTE: ReduceMean w/o keepdims reduces to a scalar value, but Burn doesn't have
         // 0-dim tensor so we can't track or perform other ops on that value
         // node.outputs[0].ty = ArgType::Scalar(tensor.elem_type);
         node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor });
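
A note for reviewers, outside the patch itself: `reduce_mean_config` collapses the ONNX `axes`/`keepdims` attributes into an `Option<usize>`, and `UnaryNode::reduce_mean` lowers that to Burn's `mean_dim` (single axis, keepdims=1) or `mean` (full reduction). The sketch below replays that mapping directly against the tensor API; the fixed rank of 4 and the sample values are illustrative, not part of the patch.

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Mirrors the ReduceMean lowering: one normalized axis with keepdims=1
/// becomes `mean_dim` (rank preserved, reduced dim has size 1); no axes
/// becomes `mean` (rank-1 result, since Burn has no 0-dim tensors).
fn reduce_mean_like<B: Backend>(input: Tensor<B, 4>, axis: Option<i64>) {
    match axis {
        Some(mut dim) => {
            // Same normalization as `reduce_mean_config`: ONNX accepts
            // axes in [-r, r-1], while Burn expects a positive dim.
            if dim < 0 {
                dim += 4; // r = rank(data) = 4 in this sketch
            }
            // e.g. shape [1, 1, 1, 4] -> [1, 1, 1, 1] for dim = 3
            let _kept: Tensor<B, 4> = input.mean_dim(dim as usize);
        }
        None => {
            // e.g. [[[[1.0, 4.0, 9.0, 25.0]]]].mean() == [9.75]
            let _full: Tensor<B, 1> = input.mean();
        }
    }
}
```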
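
Likewise, on how the new `.input("tests/reduce_mean/reduce_mean.onnx")` line in `build.rs` is consumed: the onnx-tests build script drives burn-import's `ModelGen`, which turns each registered ONNX file into a generated Rust module (here `reduce_mean::Model`, pulled in by `include_models!`). A minimal sketch of that flow, assuming the builder methods the crate already uses:

```rust
// build.rs (sketch): regenerate Rust sources from the ONNX test files.
use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        // Registers the ONNX file added by this patch; one `.input` call
        // per model, matching the chain shown in the diff above.
        .input("tests/reduce_mean/reduce_mean.onnx")
        // Generated code lands under OUT_DIR/model/.
        .out_dir("model/")
        .run_from_script();
}
```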