Skip to content

Commit

Permalink
[DialectBuilder] add builder funcrions for ONNXSumOp and ONNXConvOp (o…
Browse files Browse the repository at this point in the history
…nnx#2572)

The DialectBuilder class seems to be missing the function create the
ONNXSumOp and ONNXConOp nodes and check their shape.  This patch adds
the necessary functions.

Signed-off-by: Ashay Rane <[email protected]>
Signed-off-by: Alexandre Eichenberger <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
  • Loading branch information
ashay and AlexandreEichenberger authored Nov 13, 2023
1 parent 04e26e7 commit b618e71
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ Value OnnxBuilder::constantInt64(const ArrayRef<int64_t> intVals) const {
return constant(denseAttr);
}

Value OnnxBuilder::conv(Type Y, Value X, Value W, Value B, StringRef autoPad,
ArrayRef<int64_t> dilations, int64_t group, ArrayRef<int64_t> kernelShape,
ArrayRef<int64_t> pads, ArrayRef<int64_t> strides) const {
StringAttr autoPadAttr = b().getStringAttr(autoPad);
ArrayAttr dilationsAttr = b().getI64ArrayAttr(dilations);
IntegerAttr groupAttr =
IntegerAttr::get(b().getIntegerType(64, /*isSigned=*/true),
APInt(64, group, /*isSigned=*/true));
ArrayAttr kernelShapeAttr = b().getI64ArrayAttr(kernelShape);
ArrayAttr padsAttr = b().getI64ArrayAttr(pads);
ArrayAttr stridesAttr = b().getI64ArrayAttr(strides);
return createOpAndInferShapes<ONNXConvOp>(toTensor(Y), X, W, B, autoPadAttr,
dilationsAttr, groupAttr, kernelShapeAttr, padsAttr, stridesAttr);
}

Value OnnxBuilder::dim(Value input, int axis) const {
Type resultType = RankedTensorType::get({1}, b().getI64Type());
IntegerAttr axisAttr = getSignedInt64Attr(axis);
Expand Down Expand Up @@ -322,6 +337,10 @@ Value OnnxBuilder::sub(Value A, Value B) const {
return createOpAndInferShapes<ONNXSubOp>(toTensor(A), toTensor(B));
}

Value OnnxBuilder::sum(Type outputType, ValueRange inputs) const {
return createTypedOpAndInferShapes<ONNXSumOp>(toTensor(outputType), inputs);
}

Value OnnxBuilder::transpose(
Type outputType, Value input, ArrayAttr perm) const {
return createTypedOpAndInferShapes<ONNXTransposeOp>(
Expand Down
9 changes: 9 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ struct OnnxBuilder : DialectBuilder {
mlir::Value constant(mlir::Attribute denseAttr) const;
mlir::Value constantInt64(const mlir::ArrayRef<int64_t> intVals) const;

// ONNXConvOp
mlir::Value conv(mlir::Type Y, mlir::Value X, mlir::Value W, mlir::Value B,
llvm::StringRef autoPad, mlir::ArrayRef<int64_t> dilations, int64_t group,
mlir::ArrayRef<int64_t> kernelShape, mlir::ArrayRef<int64_t> pads,
mlir::ArrayRef<int64_t> strides) const;

// ONNXDivOp
mlir::Value div(mlir::Value A, mlir::Value B) const;

Expand Down Expand Up @@ -174,6 +180,9 @@ struct OnnxBuilder : DialectBuilder {
// ONNXSubOp
mlir::Value sub(mlir::Value A, mlir::Value B) const;

// ONNXSumOp
mlir::Value sum(mlir::Type outputType, mlir::ValueRange inputs) const;

// UnrealizedConversionCastOp
// Convert a Value to TensorType if it is of MemRefType.
mlir::Value toTensor(mlir::Value input) const;
Expand Down

0 comments on commit b618e71

Please sign in to comment.