Skip to content

Commit

Permalink
burn-import: Add transpose operator (#546)
Browse files Browse the repository at this point in the history
  • Loading branch information
Luni-4 authored Jul 27, 2023
1 parent f0a7135 commit 2148ac4
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion burn-import/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ List taken from [here](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
- [ ] ThresholdedRelu
- [ ] Tile
- [ ] TopK
- [ ] Transpose
- [x] Transpose
- [ ] Trilu
- [ ] Unique
- [ ] Unsqueeze
Expand Down
22 changes: 22 additions & 0 deletions burn-import/src/burn/node/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub enum UnaryNodeKind {
Relu,
Sigmoid,
LogSoftmax,
Transpose,
}

impl UnaryNodeKind {
Expand All @@ -33,6 +34,7 @@ impl UnaryNodeKind {
Self::Relu => "relu",
Self::Sigmoid => "sigmoid",
Self::LogSoftmax => "log_softmax",
Self::Transpose => "transpose",
}
}
}
Expand Down Expand Up @@ -104,6 +106,11 @@ impl UnaryNode {
let function = move |input| quote! { burn::tensor::activation::log_softmax(#input, #dim) };
Self::new(input, output, UnaryNodeKind::LogSoftmax, Arc::new(function))
}

pub(crate) fn transpose(input: TensorType, output: TensorType) -> Self {
let function = move |input| quote! { #input.transpose() };
Self::new(input, output, UnaryNodeKind::Transpose, Arc::new(function))
}
}

#[cfg(test)]
Expand Down Expand Up @@ -174,4 +181,19 @@ mod tests {
},
);
}

#[test]
fn test_unary_codegen_transpose() {
codegen_unary_operator::<4, _>(
UnaryNode::transpose(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
),
quote! {
let tensor2 = tensor1.transpose();

tensor2
},
);
}
}
1 change: 1 addition & 0 deletions burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pub fn dim_inference(
NodeType::Slice => slice_update_outputs(node),
NodeType::MatMul => same_as_input(node),
NodeType::Sigmoid => same_as_input(node),
NodeType::Transpose => same_as_input(node),
NodeType::Concat => concat_update_outputs(node),
NodeType::Reshape => reshape_update_outputs(node),
_ => todo!(
Expand Down
8 changes: 8 additions & 0 deletions burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ impl ONNXGraph {
NodeType::Equal => graph.register(Self::equal_conversion(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
_ => panic!("Unsupported node conversion {}", node.node_type),
}
}
Expand Down Expand Up @@ -228,6 +229,13 @@ impl ONNXGraph {
UnaryNode::flatten(input, output, start_dim, end_dim)
}

fn transpose_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();

UnaryNode::transpose(input, output)
}

fn reshape_conversion(mut node: Node) -> ReshapeNode {
let input = node.inputs.get(0).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
Expand Down

0 comments on commit 2148ac4

Please sign in to comment.