Add reduce mean ONNX op support #1637

Merged · 2 commits · Apr 16, 2024
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -140,7 +140,7 @@ represent the corresponding Burn Op.
| [ReduceLogSum][133] | ❌ | ❌ |
| [ReduceLogSumExp][134] | ❌ | ❌ |
| [ReduceMax][135] | ❌ | ✅ |
| [ReduceMean][136] | ❌ | ✅ |
| [ReduceMean][136] | ✅ | ✅ |
| [ReduceMin][137] | ❌ | ✅ |
| [ReduceProd][138] | ❌ | ❌ |
| [ReduceSum][139] | ❌ | ✅ |
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -35,6 +35,7 @@ fn main() {
.input("tests/recip/recip.onnx")
.input("tests/relu/relu.onnx")
.input("tests/leaky_relu/leaky_relu.onnx")
.input("tests/reduce_mean/reduce_mean.onnx")
.input("tests/reshape/reshape.onnx")
.input("tests/sigmoid/sigmoid.onnx")
.input("tests/softmax/softmax.onnx")
17 changes: 17 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -41,6 +41,7 @@ include_models!(
    mul,
    neg,
    recip,
    reduce_mean,
    relu,
    reshape,
    sigmoid,
@@ -442,6 +443,22 @@ mod tests {
        output3.to_data().assert_approx_eq(&expected3, 3);
    }

    #[test]
    fn reduce_mean() {
        let device = Default::default();
        let model: reduce_mean::Model<Backend> = reduce_mean::Model::new(&device);

        // Run the model
        let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);
        let (output_scalar, output_tensor, output_value) = model.forward(input.clone());
        let expected_scalar = Data::from([9.75]);
        let expected = Data::from([[[[9.75]]]]);

        assert_eq!(output_scalar.to_data(), expected_scalar);
        assert_eq!(output_tensor.to_data(), input.to_data());
        assert_eq!(output_value.to_data(), expected);
    }

    #[test]
    fn reshape() {
        // Initialize the model without weights (because the exported file does not contain them)
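For orientation, here is a minimal sketch (not part of the diff) of the plain Burn tensor calls that the three `ReduceMean` nodes in the test model are expected to lower to. It reuses the `Backend`, `Tensor`, and `device` names from the test above; `mean` and `mean_dim` are the existing Burn tensor ops.

```rust
// Illustrative only: the three torch.mean variants in reduce_mean.py map onto
// Burn's tensor API roughly as follows (same 4-d input as the test).
let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);

// torch.mean(x) — full reduction; Burn yields a rank-1 tensor, not a scalar.
let full: Tensor<Backend, 1> = input.clone().mean(); // [9.75]

// torch.mean(x, dim=1, keepdim=True) — the reduced axis is kept, rank preserved.
let dim1: Tensor<Backend, 4> = input.clone().mean_dim(1);

// torch.mean(x, dim=-1, keepdim=True) — ONNX axis -1 resolves to dim 3 here.
let last: Tensor<Backend, 4> = input.mean_dim(3); // [[[[9.75]]]]
```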
Binary file crates/burn-import/onnx-tests/tests/reduce_mean/reduce_mean.onnx not shown.
46 changes: 46 additions & 0 deletions crates/burn-import/onnx-tests/tests/reduce_mean/reduce_mean.py
@@ -0,0 +1,46 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/reduce_mean/reduce_mean.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        return (
            # ReduceMean, keepdims=0, axes=None
            torch.mean(x),
            # ReduceMean, keepdims=1, axes=[1]
            torch.mean(x, dim=1, keepdim=True),
            # ReduceMean, keepdims=1, axes=[-1]
            torch.mean(x, dim=-1, keepdim=True),
        )


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "reduce_mean.onnx"
    test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)

    torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)

    print(f"Finished exporting model to {onnx_name}")

    # Output some test data for use in the test
    print(f"Test input data: {test_input}")
    # forward takes the full 4-d tensor (unpacking it would pass a 3-d slice)
    output = model.forward(test_input)
    print(f"Test output data: {output}")


if __name__ == "__main__":
    main()
65 changes: 65 additions & 0 deletions crates/burn-import/src/burn/node/unary.rs
@@ -29,6 +29,7 @@ pub enum UnaryNodeKind {
    Log,
    LogSoftmax,
    Neg,
    ReduceMean,
    Reciprocal,
    LeakyRelu,
    Relu,
@@ -51,6 +52,7 @@ impl UnaryNodeKind {
Self::Log => "log",
Self::LogSoftmax => "log_softmax",
Self::Neg => "neg",
Self::ReduceMean => "reduce_mean",
Self::Reciprocal => "reciprocal",
Self::LeakyRelu => "leaky_relu",
Self::Relu => "relu",
@@ -246,6 +248,32 @@ impl UnaryNode {
_ => panic!("output must be a tensor"),
}
}

    pub(crate) fn reduce_mean(input: Type, output: Type, dim: Option<usize>) -> Self {
        // ReduceMean is constrained to numeric tensors, so no need to check for bool.
        if let Type::Tensor(_) = output {
            if let Some(dim) = dim {
                // ReduceMean, keepdims=1, axes=[dim]
                let dim = dim.to_tokens();
                Self::new(
                    input,
                    output,
                    UnaryNodeKind::ReduceMean,
                    Rc::new(move |input| quote! { #input.mean_dim(#dim) }),
                )
            } else {
                // ReduceMean, keepdims=0, axes=None
                Self::new(
                    input,
                    output,
                    UnaryNodeKind::ReduceMean,
                    Rc::new(move |input| quote! { #input.mean() }),
                )
            }
        } else {
            panic!("ReduceMean only supports tensor output");
        }
    }
}

#[cfg(test)]
@@ -430,6 +458,43 @@ mod tests {
        );
    }

    #[test]
    fn test_unary_codegen_reduce_mean() {
        one_node_graph(
            UnaryNode::reduce_mean(
                Type::Tensor(TensorType::new_float("tensor1", 4)),
                Type::Tensor(TensorType::new_float("tensor2", 4)),
                Some(1),
            ),
            quote! {
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
                    let tensor2 = tensor1.mean_dim(1);

                    tensor2
                }
            },
            vec!["tensor1".to_string()],
            vec!["tensor2".to_string()],
        );

        one_node_graph(
            UnaryNode::reduce_mean(
                Type::Tensor(TensorType::new_float("tensor1", 4)),
                Type::Tensor(TensorType::new_float("tensor2", 1)),
                None,
            ),
            quote! {
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1> {
                    let tensor2 = tensor1.mean();

                    tensor2
                }
            },
            vec!["tensor1".to_string()],
            vec!["tensor2".to_string()],
        );
    }

    #[test]
    fn test_unary_codegen_reciprocal() {
        one_node_graph(
8 changes: 5 additions & 3 deletions crates/burn-import/src/onnx/dim_inference.rs
@@ -39,7 +39,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
        NodeType::Mul => same_as_input(node),
        NodeType::Neg => same_as_input(node),
        NodeType::Reciprocal => same_as_input(node),
        NodeType::ReduceMean => mean_update_outputs(node),
        NodeType::ReduceMean => reduce_mean_update_outputs(node),
        NodeType::Relu => same_as_input(node),
        NodeType::Reshape => reshape_update_outputs(node),
        NodeType::Shape => shape_update_outputs(node),
@@ -205,12 +205,11 @@ fn reshape_update_outputs(node: &mut Node) {
    });
}

fn mean_update_outputs(node: &mut Node) {
fn reduce_mean_update_outputs(node: &mut Node) {
    if node.inputs.len() != 1 {
        panic!("Mean: multiple inputs are not supported");
    }

    // Extract the configuration of the linear layer (inputs are known)
    let node_input = &mut node.inputs[0];
    let tensor = match node_input.clone().ty {
        ArgType::Tensor(tensor) => tensor,
@@ -229,6 +228,9 @@
    if dim_only {
        node.outputs[0].ty = ArgType::Tensor(tensor);
    } else {
        // NOTE: ReduceMean without keepdims reduces to a scalar value, but Burn doesn't
        // have 0-dim tensors, so we can't track or perform other ops on that value.
        // node.outputs[0].ty = ArgType::Scalar(tensor.elem_type);
        node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor });
    }
}
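Since Burn has no 0-dim tensor type, the full-reduction branch above types its output as a rank-1 tensor rather than a scalar. A short sketch of what that means for downstream code (illustrative, not part of the diff; it reuses the `Backend` and `device` names from the tests):

```rust
// Illustrative only: with no axes, ReduceMean collapses every dimension, but
// the generated Burn code surfaces the result as a single-element rank-1
// tensor rather than a true scalar.
let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);
let reduced: Tensor<Backend, 1> = input.mean(); // shape [1], holding 9.75

// Pulling the value out as a plain scalar takes an explicit conversion step.
let value = reduced.into_scalar();
```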
44 changes: 44 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
@@ -660,3 +660,47 @@ fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d {
panic!("Padding configuration ({:?}) not supported", pads);
}
}

pub fn reduce_mean_config(node: &Node) -> Option<usize> {
    let mut axes = Vec::new();
    let mut keepdims = 1;

    let tensor = match node.inputs.first().unwrap().clone().ty {
        ArgType::Tensor(tensor) => tensor,
        _ => panic!("Only tensor input is valid"),
    };

    // Extract the attributes
    for (key, value) in node.attrs.iter() {
        match key.as_str() {
            "axes" => axes = value.clone().into_i64s(),
            "keepdims" => keepdims = value.clone().into_i64(),
            _ => {}
        }
    }

    if axes.len() > 1 {
        panic!("ReduceMean: reducing on multiple dimensions is not supported")
    }

    if axes.is_empty() && keepdims == 1 {
        panic!("ReduceMean: axes must be provided with keepdims")
    }

    if !axes.is_empty() && keepdims == 0 {
        // Not supported in Burn
        panic!("ReduceMean: the reduce operation must preserve the reduced dimension")
    }

    if axes.is_empty() {
        None
    } else {
        let mut dim = axes[0];

        if dim < 0 {
            // Accepted range is [-r, r-1] where r = rank(data), but Burn only supports a positive dim
            dim += tensor.dim as i64;
        }
        Some(dim as usize)
    }
}
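To make the axis arithmetic above concrete, here is a self-contained sketch of the negative-axis normalization. The helper name is illustrative (not part of the PR); it mirrors the arithmetic in `reduce_mean_config`.

```rust
// Illustrative only: map an ONNX axis in [-r, r - 1] to the non-negative dim
// that Burn's mean_dim expects, mirroring the arithmetic in reduce_mean_config.
fn normalize_axis(axis: i64, rank: usize) -> usize {
    let r = rank as i64;
    assert!((-r..r).contains(&axis), "axis out of range for rank {rank}");
    if axis < 0 { (axis + r) as usize } else { axis as usize }
}

#[test]
fn normalize_axis_matches_reduce_mean_config() {
    assert_eq!(normalize_axis(-1, 4), 3); // axes = [-1] -> mean_dim(3)
    assert_eq!(normalize_axis(1, 4), 1); // axes = [1]  -> mean_dim(1)
}
```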
10 changes: 10 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
@@ -252,6 +252,7 @@ impl OnnxGraph {
            NodeType::Sqrt => graph.register(Self::sqrt_conversion(node)),
            NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
            NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
            NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
            NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
            NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
            NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
@@ -462,6 +463,15 @@ impl OnnxGraph {

        ReshapeNode::new(input, output, shape)
    }

    fn reduce_mean_conversion(node: Node) -> UnaryNode {
        let input = node.inputs.first().unwrap().to_type();
        let output = node.outputs.first().unwrap().to_type();
        let dim = reduce_mean_config(&node);

        UnaryNode::reduce_mean(input, output, dim)
    }

    fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
        let input = node.inputs.first().unwrap().to_tensor_type();
        let output = node.outputs.first().unwrap().to_tensor_type();
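Taken together, the pieces in this PR form a small pipeline: `reduce_mean_update_outputs` fixes the output type during dim inference, `reduce_mean_config` extracts the optional dim, and `reduce_mean_conversion` builds the `UnaryNode` whose closure emits the `mean`/`mean_dim` call. A hedged trace of one node through that flow (illustrative; `node` and `graph` stand in for the values in scope inside `OnnxGraph`):

```rust
// Illustrative only: one ONNX ReduceMean node (axes = [-1], keepdims = 1,
// rank-4 float input) traced through the functions added in this PR.
//
// 1. dim_inference: reduce_mean_update_outputs keeps the input rank, so the
//    output is typed as a rank-4 tensor.
// 2. op_configuration: reduce_mean_config normalizes axes = [-1] to Some(3).
// 3. to_burn: reduce_mean_conversion wires it into a UnaryNode whose codegen
//    closure emits `tensor.mean_dim(3)`.
let dim = reduce_mean_config(&node); // -> Some(3)
let unary = UnaryNode::reduce_mean(
    node.inputs.first().unwrap().to_type(),  // Type::Tensor (float, rank 4)
    node.outputs.first().unwrap().to_type(), // Type::Tensor (float, rank 4)
    dim,
);
graph.register(unary);
```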