Merge remote-tracking branch 'upstream/main' into fuse-collapse
chentong319 committed Nov 13, 2023
2 parents 96efa12 + b618e71 commit c6b9f02
Showing 46 changed files with 500 additions and 224 deletions.
6 changes: 3 additions & 3 deletions docs/LocationInfo.md
@@ -72,9 +72,9 @@ Then location info can be found in the output of test_add.onnx.onnx.mlir
The test_add.onnx.mlir content:

```
1 module attributes {llvm.data_layout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-apple-darwin22.3.0", "onnx-mlir.symbol-postfix" = "test_add"} {
2 func.func @main_graph(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> attributes {input_names = ["x", "y"], output_names = ["sum"]} {
3 %0 = "onnx.Add"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
1 module attributes {llvm.data_layout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-apple-darwin22.3.0", "onnx-mlir.symbol-postfix" = "test_add"} {
2 func.func @main_graph(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
3 %0 = "onnx.Add"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
4 onnx.Return %0 : tensor<3x4x5xf32>
5 }
6 "onnx.EntryPoint"() {func = @main_graph} : () -> ()
2 changes: 1 addition & 1 deletion docs/Testing.md
@@ -316,7 +316,7 @@ $gdb Debug/bin/run-onnx-lib
(gdb) run ./test_add.so
(gdb) list
1 builtin.module {
2 builtin.func @main_graph(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> attributes {input_names = ["x", "y"], output_names = ["sum"]} {
2 builtin.func @main_graph(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
3 %0 = "onnx.Add"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
4 return %0 : tensor<3x4x5xf32>
5 }
126 changes: 111 additions & 15 deletions src/Builder/FrontendDialectTransformer.cpp
@@ -190,6 +190,7 @@ class FrontendGenImpl {
// mapping between string name and symbol
ValueSymbolMapping frontend_symbols_;

// Keep shape information set by users.
ModelInputShaper modelInputShaper_;

using ImportHandlerType = void (onnx_mlir::detail::FrontendGenImpl::*)(
@@ -286,8 +287,10 @@ class FrontendGenImpl {
/*!
* Import an onnx tensor type by determining and returning its type
* @param type_proto onnx tensor TypeProto.
* @param dim_params a comma-separated string of dimIndex:dimParam.
*/
Type ImportTensorType(const onnx::TypeProto &type_proto) {
Type ImportTensorType(
const onnx::TypeProto &type_proto, std::string *dim_params = nullptr) {
assert(type_proto.value_case() == onnx::TypeProto::kTensorType &&
"expect tensor type");
std::vector<int64_t> dims;
@@ -300,6 +303,7 @@
auto shape_proto = tensor_type.shape();
for (int i = 0; i < shape_proto.dim_size(); i++) {
if (shape_proto.dim()[i].dim_value()) {
// Dim is a constant value.
int dim_numeric_size = shape_proto.dim()[i].dim_value();
assert(dim_numeric_size != 0 &&
"Parsed an tensor with a dimension size of zero");
@@ -309,6 +313,14 @@
// If dim_value < 0, then dim is parametric.
dims.push_back(ShapedType::kDynamic);
}
} else if (dim_params && shape_proto.dim()[i].has_dim_param()) {
// Dim is unknown but assigned a string ID that can be used to check
// equality between unknown dimensions.
if (!dim_params->empty())
*dim_params += ",";
*dim_params +=
std::to_string(i) + ":" + shape_proto.dim()[i].dim_param();
dims.push_back(ShapedType::kDynamic);
} else {
dims.push_back(ShapedType::kDynamic);
}
@@ -318,13 +330,14 @@
return RankedTensorType::get(tensor_dims, elementType);
}

Type ImportSequenceType(const onnx::TypeProto &type_proto) {
Type ImportSequenceType(
const onnx::TypeProto &type_proto, std::string *dim_params = nullptr) {
auto input_seq_type = type_proto.sequence_type();
if (input_seq_type.has_elem_type()) {
onnx::TypeProto elem_type = input_seq_type.elem_type();
assert(elem_type.value_case() == onnx::TypeProto::kTensorType &&
"expect tensor inside sequence type");
Type mlir_elem_type = ImportTensorType(elem_type);
Type mlir_elem_type = ImportTensorType(elem_type, dim_params);
if (!mlir_elem_type.isa<ShapedType>())
llvm_unreachable("Seq type is incorrect");
Type seq_type = mlir::SeqType::get(mlir_elem_type.cast<ShapedType>(), -1);
@@ -343,13 +356,14 @@
llvm_unreachable("unexpected type");
}

Type ImportType(const onnx::TypeProto &type_proto) {
Type ImportType(
const onnx::TypeProto &type_proto, std::string *dim_params = nullptr) {
switch (type_proto.value_case()) {
case onnx::TypeProto::kTensorType:
return ImportTensorType(type_proto);
return ImportTensorType(type_proto, dim_params);
break;
case onnx::TypeProto::kSequenceType:
return ImportSequenceType(type_proto);
return ImportSequenceType(type_proto, dim_params);
break;
case onnx::TypeProto::kOptionalType:
return ImportOptionalType(type_proto);
@@ -468,17 +482,26 @@ class FrontendGenImpl {
// * maintain a list of the defined graph
llvm::SmallVector<Type, 4> argTypes;

llvm::SmallVector<llvm::StringRef, 4> inputNames;
llvm::SmallVector<llvm::StringRef, 4> outputNames;
llvm::SmallVector<llvm::StringRef, 4> inputNames, outputNames;
// Keep dim_param for each dynamic dimension of each input tensor.
// In ONNX specification, two dynamic dimensions with the same dim_param
// string would be the same at runtime.
//
// See https://github.com/onnx/onnx/blob/main/docs/IR.md for more
// information about dim_param.
llvm::SmallVector<std::string, 4> inputDimParams, outputDimParams;

// Import the input tensor types that are not constant and not initialized.
int inputIndex = 0;
for (const auto &input : graph.input()) {
AddValueInfo(input);
if (initializerNames.count(input.name()) == 0) {
inputNames.push_back(input.name());
Type argTy = ImportType(input.type());
std::string dimParams = "";
Type argTy = ImportType(input.type(), &dimParams);
argTy = modelInputShaper_.reshape(inputIndex, argTy);
if (!dimParams.empty())
inputDimParams.emplace_back(dimParams);

argTypes.emplace_back(argTy);

@@ -524,7 +547,10 @@ class FrontendGenImpl {
llvm::SmallVector<Value, 4> retVals;
// Import the output tensors
for (const auto &output : graph.output()) {
ImportOutputTensor(output, retTys, retVals);
std::string dimParams = "";
ImportOutputTensor(output, retTys, retVals, &dimParams);
if (!dimParams.empty())
outputDimParams.emplace_back(dimParams);
}

if (useReturn)
@@ -533,8 +559,21 @@
// Create a return operation to return all ONNX output tensors.
builder_.create<ONNXYieldOp>(UnknownLoc(), retVals);

op->setAttr("input_names", builder_.getStrArrayAttr(inputNames));
op->setAttr("output_names", builder_.getStrArrayAttr(outputNames));
SmallVector<llvm::StringRef> inputDimParamsRefs, outputDimParamsRefs;
for (uint64_t i = 0; i < inputDimParams.size(); ++i)
inputDimParamsRefs.emplace_back(llvm::StringRef(inputDimParams[i]));
for (uint64_t i = 0; i < outputDimParams.size(); ++i)
outputDimParamsRefs.emplace_back(llvm::StringRef(outputDimParams[i]));
if (!inputNames.empty())
op->setAttr("input_names", builder_.getStrArrayAttr(inputNames));
if (!outputNames.empty())
op->setAttr("output_names", builder_.getStrArrayAttr(outputNames));
if (!inputDimParamsRefs.empty())
op->setAttr(
"input_dim_params", builder_.getStrArrayAttr(inputDimParamsRefs));
if (!outputDimParamsRefs.empty())
op->setAttr(
"output_dim_params", builder_.getStrArrayAttr(outputDimParamsRefs));

frontend_symbols_.popScope(graph.name());
onnx_type_map.popScope(graph.name());
@@ -1328,23 +1367,73 @@ class FrontendGenImpl {
* @param output onnx output ValueInfoProto.
* @param ret_types a vector of types representing graph's output types.
* @param ret_vals a vector of mlir Value representing graph's output.
* @param dim_params a comma-separated string of dimIndex:dimParam.
*/
void ImportOutputTensor(const onnx::ValueInfoProto &output,
llvm::SmallVectorImpl<Type> &ret_types,
llvm::SmallVectorImpl<Value> &ret_vals) {
llvm::SmallVectorImpl<Value> &ret_vals,
std::string *dim_params = nullptr) {
const Value *valPtr = frontend_symbols_.GetByOnnxName(output.name());
Value val = *valPtr;
if (output.type().value_case() == onnx::TypeProto::kTensorType) {
if (output.type().tensor_type().has_shape()) {
val.setType(ImportType(output.type()));
val.setType(ImportType(output.type(), dim_params));
}
ret_types.emplace_back(val.getType());
} else {
ret_types.emplace_back(ImportType(output.type()));
ret_types.emplace_back(ImportType(output.type(), dim_params));
}
ret_vals.push_back(val);
}

// Move function attributes for argument/result names and dim_params into
// argument/result attributes.
void moveFuncAttrsToArgAttrs(func::FuncOp funcOp,
ArrayRef<std::string> funcAttrNames, ArrayRef<std::string> argAttrNames,
bool isArg) {
assert(funcAttrNames.size() == argAttrNames.size() &&
       "Mismatched number of attributes to move");
Operation *op = funcOp.getOperation();
size_t numOfArgs =
(isArg) ? funcOp.getNumArguments() : funcOp.getNumResults();

// Only move attributes that exist.
SmallVector<ArrayAttr, 2> funcAttrsToMove;
SmallVector<std::string, 2> targetArgAttrNames;
for (size_t i = 0; i < funcAttrNames.size(); ++i) {
ArrayAttr attr = op->getAttrOfType<ArrayAttr>(funcAttrNames[i]);
if (!attr)
continue;
funcAttrsToMove.emplace_back(attr);
targetArgAttrNames.emplace_back(argAttrNames[i]);
}

// Move function attributes to argument/result attributes.
for (size_t i = 0; i < numOfArgs; ++i) {
SmallVector<NamedAttribute, 2> argAttrs;
for (size_t k = 0; k < funcAttrsToMove.size(); ++k) {
if (i < funcAttrsToMove[k].size()) {
auto name = (funcAttrsToMove[k].getValue()[i]).cast<StringAttr>();
if (name) {
NamedAttribute namedAttr =
builder_.getNamedAttr(argAttrNames[k], name);
argAttrs.emplace_back(namedAttr);
}
}
}
if (!argAttrs.empty()) {
if (isArg)
funcOp.setArgAttrs(i, argAttrs);
else
funcOp.setResultAttrs(i, argAttrs);
}
}

// Clean up the function attributes.
for (std::string s : funcAttrNames)
op->removeAttr(s);
}

/*!
* Import ONNX main computation graph.
* @param graph onnx graph proto.
@@ -1363,6 +1452,13 @@
/*op=*/mainFunc.getOperation(), /*useReturn=*/true);
mainFunc.setType(funcType);

// Move function attributes for argument/result names and dim_params into
// argument/result attributes.
moveFuncAttrsToArgAttrs(mainFunc, {"input_names", "input_dim_params"},
{"onnx.name", "onnx.dim_params"}, /*isArg=*/true);
moveFuncAttrsToArgAttrs(mainFunc, {"output_names", "output_dim_params"},
{"onnx.name", "onnx.dim_params"}, /*isArg=*/false);

// Emit entry point op describing inference function signature.
auto entryPoint = ONNXEntryPointOp::create(UnknownLoc(), mainFunc);
module_.push_back(entryPoint);
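The dim_params plumbing added above is easiest to see on a concrete shape. The standalone sketch below reproduces just the encoding step of ImportTensorType, with the onnx::TensorShapeProto replaced by a plain vector of (dim_value, dim_param) pairs; the shape and names are illustrative assumptions, not taken from the commit, and -1 stands in for a dim that carries only a symbolic dim_param.

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Standalone sketch of the encoding in ImportTensorType: every dynamic
// dimension that carries a dim_param is recorded as "dimIndex:dimParam",
// and entries are joined with commas.
using Dim = std::pair<int64_t, std::string>; // (dim_value, dim_param)

std::string encodeDimParams(const std::vector<Dim> &shape) {
  std::string dimParams;
  for (size_t i = 0; i < shape.size(); ++i) {
    const auto &[value, param] = shape[i];
    // Static dims and anonymous dynamic dims contribute nothing.
    if (value > 0 || param.empty())
      continue;
    if (!dimParams.empty())
      dimParams += ",";
    dimParams += std::to_string(i) + ":" + param;
  }
  return dimParams;
}

int main() {
  // Hypothetical ONNX input of shape [batch, 3, seqlen].
  std::vector<Dim> shape = {{-1, "batch"}, {3, ""}, {-1, "seqlen"}};
  std::cout << encodeDimParams(shape) << "\n"; // prints "0:batch,2:seqlen"
}
```

After moveFuncAttrsToArgAttrs has run, a string like "0:batch,2:seqlen" surfaces on the entry function as an onnx.dim_params argument attribute alongside onnx.name, so later passes can tell that two dynamic dimensions sharing the same dim_param must agree at runtime.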
52 changes: 23 additions & 29 deletions src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -118,41 +118,26 @@ class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
parsingFailure = false;
auto inputs = funcType.getInputs();
auto outputs = funcType.getResults();

ArrayAttr inputNames = op->getAttrOfType<ArrayAttr>("input_names");
if (!inputNames) {
SmallVector<StringRef, 4> names;
for (uint64_t i = 0; i < inputs.size(); ++i)
names.emplace_back(StringRef("input_" + std::to_string(i)));
inputNames = b.getStrArrayAttr(names);
} else if (inputNames.size() != inputs.size()) {
llvm::errs()
<< "Please ensure that the 'input_name' function attribute has "
"the same number of names as function parameters.";
parsingFailure = true;
return "";
}
ArrayAttr outputNames = op->getAttrOfType<ArrayAttr>("output_names");
if (!outputNames) {
SmallVector<StringRef, 4> names;
for (uint64_t i = 0; i < outputs.size(); ++i)
names.emplace_back(StringRef("output_" + std::to_string(i)));
outputNames = b.getStrArrayAttr(names);
} else if (outputNames.size() != outputs.size()) {
llvm::errs()
<< "Please ensure that the 'output_name' function attribute has "
"the same number of names as function results.";
parsingFailure = true;
return "";
}
auto funcOp = dyn_cast_or_null<func::FuncOp>(op);
ArrayAttr argAttrs = funcOp.getArgAttrsAttr();
ArrayAttr resAttrs = funcOp.getResAttrsAttr();

std::string dString;
llvm::raw_string_ostream dstream(dString);
dstream << "[ ";
std::string comma = std::string("");
for (unsigned int i = 0; i < funcType.getNumInputs(); i++) {
dstream << comma;
concatTypeString(inputs[i], inputNames[i], dstream);
StringAttr inputName = b.getStringAttr({"input_" + std::to_string(i)});
if (argAttrs) {
DictionaryAttr dictAttrs = llvm::dyn_cast<DictionaryAttr>(argAttrs[i]);
if (dictAttrs && dictAttrs.contains("onnx.name"))
inputName = dictAttrs.getNamed("onnx.name")
.value()
.getValue()
.cast<StringAttr>();
}
concatTypeString(inputs[i], inputName, dstream);
comma = std::string(" , ");
}
dstream << "\n]";
@@ -162,7 +147,16 @@ class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
comma = std::string("");
for (unsigned int i = 0; i < funcType.getNumResults(); i++) {
dstream << comma;
concatTypeString(outputs[i], outputNames[i], dstream);
StringAttr outputName = b.getStringAttr({"output_" + std::to_string(i)});
if (resAttrs) {
DictionaryAttr dictAttrs = llvm::dyn_cast<DictionaryAttr>(resAttrs[i]);
if (dictAttrs && dictAttrs.contains("onnx.name"))
outputName = dictAttrs.getNamed("onnx.name")
.value()
.getValue()
.cast<StringAttr>();
}
concatTypeString(outputs[i], outputName, dstream);
comma = std::string(" , ");
}
dstream << "\n]";
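The two loops above repeat the same lookup: take onnx.name from the i-th entry of the argument/result attribute array when present, otherwise fall back to a positional default. A shared helper along these lines (a sketch only, not part of this commit) would remove the duplication:

```cpp
#include <string>

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;

// Sketch of a helper for the duplicated lookup above; not part of this
// commit. Returns the "onnx.name" entry of the i-th attribute dictionary
// when present, otherwise "<prefix>_<i>".
static StringAttr getNameOrDefault(
    OpBuilder &b, ArrayAttr attrs, unsigned i, llvm::StringRef prefix) {
  if (attrs && i < attrs.size())
    if (auto dict = llvm::dyn_cast<DictionaryAttr>(attrs[i]))
      if (auto name = dict.getAs<StringAttr>("onnx.name"))
        return name;
  return b.getStringAttr(prefix.str() + "_" + std::to_string(i));
}
```

With it, both loops would reduce to concatTypeString(inputs[i], getNameOrDefault(b, argAttrs, i, "input"), dstream) and the result counterpart, and the null check naturally covers a function that has argument attributes but no result attributes.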
19 changes: 19 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.cpp
@@ -90,6 +90,21 @@ Value OnnxBuilder::constantInt64(const ArrayRef<int64_t> intVals) const {
return constant(denseAttr);
}

Value OnnxBuilder::conv(Type Y, Value X, Value W, Value B, StringRef autoPad,
ArrayRef<int64_t> dilations, int64_t group, ArrayRef<int64_t> kernelShape,
ArrayRef<int64_t> pads, ArrayRef<int64_t> strides) const {
StringAttr autoPadAttr = b().getStringAttr(autoPad);
ArrayAttr dilationsAttr = b().getI64ArrayAttr(dilations);
IntegerAttr groupAttr =
IntegerAttr::get(b().getIntegerType(64, /*isSigned=*/true),
APInt(64, group, /*isSigned=*/true));
ArrayAttr kernelShapeAttr = b().getI64ArrayAttr(kernelShape);
ArrayAttr padsAttr = b().getI64ArrayAttr(pads);
ArrayAttr stridesAttr = b().getI64ArrayAttr(strides);
return createOpAndInferShapes<ONNXConvOp>(toTensor(Y), X, W, B, autoPadAttr,
dilationsAttr, groupAttr, kernelShapeAttr, padsAttr, stridesAttr);
}

Value OnnxBuilder::dim(Value input, int axis) const {
Type resultType = RankedTensorType::get({1}, b().getI64Type());
IntegerAttr axisAttr = getSignedInt64Attr(axis);
@@ -322,6 +337,10 @@ Value OnnxBuilder::sub(Value A, Value B) const {
return createOpAndInferShapes<ONNXSubOp>(toTensor(A), toTensor(B));
}

Value OnnxBuilder::sum(Type outputType, ValueRange inputs) const {
return createTypedOpAndInferShapes<ONNXSumOp>(toTensor(outputType), inputs);
}

Value OnnxBuilder::transpose(
Type outputType, Value input, ArrayAttr perm) const {
return createTypedOpAndInferShapes<ONNXTransposeOp>(
9 changes: 9 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.hpp
@@ -62,6 +62,12 @@ struct OnnxBuilder : DialectBuilder {
mlir::Value constant(mlir::Attribute denseAttr) const;
mlir::Value constantInt64(const mlir::ArrayRef<int64_t> intVals) const;

// ONNXConvOp
mlir::Value conv(mlir::Type Y, mlir::Value X, mlir::Value W, mlir::Value B,
llvm::StringRef autoPad, mlir::ArrayRef<int64_t> dilations, int64_t group,
mlir::ArrayRef<int64_t> kernelShape, mlir::ArrayRef<int64_t> pads,
mlir::ArrayRef<int64_t> strides) const;

// ONNXDivOp
mlir::Value div(mlir::Value A, mlir::Value B) const;

@@ -174,6 +180,9 @@ struct OnnxBuilder : DialectBuilder {
// ONNXSubOp
mlir::Value sub(mlir::Value A, mlir::Value B) const;

// ONNXSumOp
mlir::Value sum(mlir::Type outputType, mlir::ValueRange inputs) const;

// UnrealizedConversionCastOp
// Convert a Value to TensorType if it is of MemRefType.
mlir::Value toTensor(mlir::Value input) const;
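The new conv and sum entry points compose with the existing OnnxBuilder helpers. A hypothetical call site inside a rewrite pattern might look like the sketch below; rewriter, loc, the operand values, and the result types are assumptions for illustration, not code from this commit.

```cpp
// Hypothetical use of the new builder methods; assumes rewriter, loc,
// x (input), w (weights), bias, a0, a1, and the result types exist.
onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(rewriter, loc);

// 3x3 convolution: stride 1, one pixel of padding on each side, one group.
mlir::Value conv = create.onnx.conv(convResultType, x, w, bias,
    /*autoPad=*/"NOTSET", /*dilations=*/{1, 1}, /*group=*/1,
    /*kernelShape=*/{3, 3}, /*pads=*/{1, 1, 1, 1}, /*strides=*/{1, 1});

// Variadic elementwise sum of two tensors into the given result type.
mlir::Value total = create.onnx.sum(sumResultType, mlir::ValueRange{a0, a1});
```

Both helpers run shape inference on the op they create, so the result types passed in act as hints that inference can refine; the group argument is wrapped in a signed 64-bit IntegerAttr, matching the signed attribute type the ONNX dialect uses for ONNXConvOp.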
