diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp
index 8929fe103d..3f6ec0be86 100644
--- a/src/Builder/FrontendDialectTransformer.cpp
+++ b/src/Builder/FrontendDialectTransformer.cpp
@@ -65,7 +65,6 @@ namespace detail {
 using ValueSymbolMapping = SymbolMapping<Value>;
 using SymbolToOnnxTypeMapping = SymbolMapping<onnx::TypeProto>;
-using InitializedTensorMapping = SymbolMapping<onnx::TensorProto>;
 
 class FrontendGenImpl {
 public:
@@ -121,11 +120,6 @@ class FrontendGenImpl {
   // onnxop: the top version in third_part/onnx
   std::map op_dialect_top_version_map_;
 
-  /*!
-   * The list of tensors initialized by the ONNX model.
-   */
-  InitializedTensorMapping initializedTensors;
-
   // mapping between string name and symbol
   ValueSymbolMapping frontend_symbols_;
 
@@ -330,18 +324,6 @@ class FrontendGenImpl {
     return llvm::Optional();
   }
 
-  /*!
-   * Import a input tensor symbol by recording a new entry in frontend_symbols_
-   * recording the mapping between legalized onnx tensor name and Value
-   * for further lookup in computation node importing.
-   * @param input onnx input tensor ValueInfoProto.
-   * @param symbol mlir input argument.
-   */
-  void ImportInputTensorSymbol(
-      const onnx::ValueInfoProto &input, Value symbol) {
-    BindOnnxName(input.name(), symbol);
-  }
-
   NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
       onnx::AttributeProto attr) {
     Attribute mlirAttr;
@@ -424,14 +406,14 @@ class FrontendGenImpl {
   FunctionType importGraph(const onnx::GraphProto &graph, Region &region,
       Operation *op, bool useStdReturn) {
     frontend_symbols_.pushScope(graph.name());
-    initializedTensors.pushScope(graph.name());
     onnx_type_map.pushScope(graph.name());
 
     Block *entryBlock = &region.back();
 
     // Maintain a mapping between the parameter and its initializer.
+    std::unordered_set<std::string> initializerNames;
     for (const auto &initializer : graph.initializer()) {
-      const auto &initializerName = initializer.name();
-      initializedTensors.AddMapping(initializerName, initializer);
+      BindOnnxName(initializer.name(), ImportTensor(initializer));
+      initializerNames.insert(initializer.name());
     }
 
     // create a function for the graph
@@ -447,7 +429,7 @@ class FrontendGenImpl {
     int numInputs = 0;
     for (const auto &input : graph.input()) {
       AddValueInfo(input);
-      if (!initializedTensors.ContainsKey(input.name())) {
+      if (initializerNames.count(input.name()) == 0) {
         inputNames.push_back(input.name());
         auto argTy = ImportType(input.type());
         auto shapedTy = argTy.dyn_cast();
@@ -507,10 +489,10 @@ class FrontendGenImpl {
     // Counter of un-initialized tensors. This counter is used to index the
     // entry block arguments.
     int entryBlockArgIdx = 0;
-    for (const onnx::ValueInfoProto &inputProto : graph.input()) {
-      if (!initializedTensors.ContainsKey(inputProto.name())) {
-        ImportInputTensorSymbol(
-            inputProto, entryBlock->getArguments()[entryBlockArgIdx]);
+    for (const auto &input : graph.input()) {
+      if (initializerNames.count(input.name()) == 0) {
+        BindOnnxName(
+            input.name(), entryBlock->getArguments()[entryBlockArgIdx]);
         entryBlockArgIdx++;
       }
     }
@@ -537,7 +519,6 @@ class FrontendGenImpl {
     op->setAttr("output_names", builder_.getStrArrayAttr(outputNames));
 
     frontend_symbols_.popScope(graph.name());
-    initializedTensors.popScope(graph.name());
     onnx_type_map.popScope(graph.name());
     return builder_.getFunctionType(argTypes, retTys);
   }
@@ -723,11 +704,7 @@ class FrontendGenImpl {
       if (item.empty()) {
         inputs.emplace_back(none());
       } else {
-        if (const onnx::TensorProto *tensorPtr =
-                initializedTensors.GetByOnnxName(item)) {
-          inputs.push_back(ImportTensor(*tensorPtr));
-        } else if (const Value *valuePtr =
-                       frontend_symbols_.GetByOnnxName(item)) {
+        if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
           inputs.push_back(*valuePtr);
         }
       }
@@ -797,11 +774,7 @@ class FrontendGenImpl {
         // Optional inputs using empty string will be imported as NoneType.
         inputs.emplace_back(none());
       } else {
-        if (const onnx::TensorProto *tensorPtr =
-                initializedTensors.GetByOnnxName(item)) {
-          inputs.push_back(ImportTensor(*tensorPtr));
-        } else if (const Value *valuePtr =
-                       frontend_symbols_.GetByOnnxName(item)) {
+        if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
          inputs.push_back(*valuePtr);
         }
       }
@@ -852,13 +825,8 @@ class FrontendGenImpl {
     // Copy the provided inputs first
     std::vector<Value> inputs;
     for (const auto &item : node.input()) {
-      if (const onnx::TensorProto *tensorPtr =
-              initializedTensors.GetByOnnxName(item)) {
-        inputs.push_back(ImportTensor(*tensorPtr));
-      } else {
-        if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
-          inputs.push_back(*valuePtr);
-        }
+      if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
+        inputs.push_back(*valuePtr);
       }
     }
 
@@ -943,16 +911,11 @@ class FrontendGenImpl {
     };
 
     for (const auto &item : llvm::enumerate(node.input())) {
-      if (const onnx::TensorProto *tensorPtr =
-              initializedTensors.GetByOnnxName(item.value())) {
-        inVals[item.index()] = ImportTensor(*tensorPtr);
+      if (const Value *valuePtr =
+              frontend_symbols_.GetByOnnxName(item.value())) {
+        inVals[item.index()] = *valuePtr;
       } else {
-        if (const Value *valuePtr =
-                frontend_symbols_.GetByOnnxName(item.value())) {
-          inVals[item.index()] = *valuePtr;
-        } else {
-          assert(false && "Unknown input");
-        }
+        assert(false && "Unknown input");
       }
     }
 
@@ -1006,11 +969,7 @@ class FrontendGenImpl {
     // Copy the provided inputs first.
     std::vector<Value> inputs;
     for (const auto &item : node.input()) {
-      if (const onnx::TensorProto *tensorPtr =
-              initializedTensors.GetByOnnxName(item)) {
-        inputs.push_back(ImportTensor(*tensorPtr));
-      } else if (const Value *valuePtr =
-                     frontend_symbols_.GetByOnnxName(item)) {
+      if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
         inputs.push_back(*valuePtr);
       }
     }