remove initializedTensors map from ONNX parser (#1739)
instead, insert a constant for each initializer into the frontend_symbols_ symbol map

this is better for several reasons (see the sketch below):

1. initializedTensors were incorrectly visible to function bodies (in
   TryImportFunctionCallNode)

2. nested bindings didn't hide initializedTensors within their scope

3. it's more efficient to create a constant for each initializer once
   rather than every time an initializer is accessed

4. it's simpler to look up a single symbol mapping

Signed-off-by: Soren Lassen <[email protected]>
Co-authored-by: chentong319 <[email protected]>
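
To make the change concrete, here is a condensed sketch of the new import flow. It is assembled from the hunks in the diff below (BindOnnxName, ImportTensor, frontend_symbols_, and initializerNames are the names used there) and abbreviates the surrounding importGraph logic rather than reproducing it verbatim:

    // Bind a constant for each initializer exactly once, and remember which
    // graph inputs are covered by an initializer.
    std::unordered_set<std::string> initializerNames;
    for (const auto &initializer : graph.initializer()) {
      BindOnnxName(initializer.name(), ImportTensor(initializer));
      initializerNames.insert(initializer.name());
    }

    // Only inputs without an initializer become entry-block arguments.
    for (const auto &input : graph.input()) {
      if (initializerNames.count(input.name()) == 0) {
        // ... create the block argument and BindOnnxName(input.name(), arg) ...
      }
    }

    // Node inputs (called "item" in the diff) now resolve through a single
    // symbol lookup, whether they name an initializer or a block argument.
    if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item))
      inputs.push_back(*valuePtr);
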
sorenlassen and chentong319 authored Sep 28, 2022
1 parent 62e08f8 commit 9fcbbf3
Showing 1 changed file with 17 additions and 58 deletions.
src/Builder/FrontendDialectTransformer.cpp (17 additions, 58 deletions)
@@ -65,7 +65,6 @@ namespace detail {

 using ValueSymbolMapping = SymbolMapping<Value>;
 using SymbolToOnnxTypeMapping = SymbolMapping<onnx::TypeProto>;
-using InitializedTensorMapping = SymbolMapping<onnx::TensorProto>;

 class FrontendGenImpl {
 public:
@@ -121,11 +120,6 @@ class FrontendGenImpl {
   // onnxop: the top version in third_part/onnx
   std::map<std::string, int> op_dialect_top_version_map_;

-  /*!
-   * The list of tensors initialized by the ONNX model.
-   */
-  InitializedTensorMapping initializedTensors;
-
   // mapping between string name and symbol
   ValueSymbolMapping frontend_symbols_;

@@ -330,18 +324,6 @@ class FrontendGenImpl {
     return llvm::Optional<Type>();
   }

-  /*!
-   * Import a input tensor symbol by recording a new entry in frontend_symbols_
-   * recording the mapping between legalized onnx tensor name and Value
-   * for further lookup in computation node importing.
-   * @param input onnx input tensor ValueInfoProto.
-   * @param symbol mlir input argument.
-   */
-  void ImportInputTensorSymbol(
-      const onnx::ValueInfoProto &input, Value symbol) {
-    BindOnnxName(input.name(), symbol);
-  }
-
   NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
       onnx::AttributeProto attr) {
     Attribute mlirAttr;
@@ -424,14 +406,14 @@ class FrontendGenImpl {
   FunctionType importGraph(const onnx::GraphProto &graph, Region &region,
       Operation *op, bool useStdReturn) {
     frontend_symbols_.pushScope(graph.name());
-    initializedTensors.pushScope(graph.name());
     onnx_type_map.pushScope(graph.name());
     Block *entryBlock = &region.back();

     // Maintain a mapping between the parameter and its initializer.
+    std::unordered_set<std::string> initializerNames;
     for (const auto &initializer : graph.initializer()) {
-      const auto &initializerName = initializer.name();
-      initializedTensors.AddMapping(initializerName, initializer);
+      BindOnnxName(initializer.name(), ImportTensor(initializer));
+      initializerNames.insert(initializer.name());
     }

     // create a function for the graph
@@ -447,7 +429,7 @@
     int numInputs = 0;
     for (const auto &input : graph.input()) {
       AddValueInfo(input);
-      if (!initializedTensors.ContainsKey(input.name())) {
+      if (initializerNames.count(input.name()) == 0) {
         inputNames.push_back(input.name());
         auto argTy = ImportType(input.type());
         auto shapedTy = argTy.dyn_cast<RankedTensorType>();
@@ -507,10 +489,10 @@
     // Counter of un-initialized tensors. This counter is used to index the
     // entry block arguments.
     int entryBlockArgIdx = 0;
-    for (const onnx::ValueInfoProto &inputProto : graph.input()) {
-      if (!initializedTensors.ContainsKey(inputProto.name())) {
-        ImportInputTensorSymbol(
-            inputProto, entryBlock->getArguments()[entryBlockArgIdx]);
+    for (const auto &input : graph.input()) {
+      if (initializerNames.count(input.name()) == 0) {
+        BindOnnxName(
+            input.name(), entryBlock->getArguments()[entryBlockArgIdx]);
         entryBlockArgIdx++;
       }
     }
@@ -537,7 +519,6 @@
     op->setAttr("output_names", builder_.getStrArrayAttr(outputNames));

     frontend_symbols_.popScope(graph.name());
-    initializedTensors.popScope(graph.name());
     onnx_type_map.popScope(graph.name());
     return builder_.getFunctionType(argTypes, retTys);
   }
@@ -723,11 +704,7 @@ class FrontendGenImpl {
       if (item.empty()) {
         inputs.emplace_back(none());
       } else {
-        if (const onnx::TensorProto *tensorPtr =
-                initializedTensors.GetByOnnxName(item)) {
-          inputs.push_back(ImportTensor(*tensorPtr));
-        } else if (const Value *valuePtr =
-                       frontend_symbols_.GetByOnnxName(item)) {
+        if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
           inputs.push_back(*valuePtr);
         }
       }
@@ -797,11 +774,7 @@ class FrontendGenImpl {
         // Optional inputs using empty string will be imported as NoneType.
         inputs.emplace_back(none());
       } else {
-        if (const onnx::TensorProto *tensorPtr =
-                initializedTensors.GetByOnnxName(item)) {
-          inputs.push_back(ImportTensor(*tensorPtr));
-        } else if (const Value *valuePtr =
-                       frontend_symbols_.GetByOnnxName(item)) {
+        if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
           inputs.push_back(*valuePtr);
         }
       }
@@ -852,13 +825,8 @@ class FrontendGenImpl {
     // Copy the provided inputs first
     std::vector<Value> inputs;
     for (const auto &item : node.input()) {
-      if (const onnx::TensorProto *tensorPtr =
-              initializedTensors.GetByOnnxName(item)) {
-        inputs.push_back(ImportTensor(*tensorPtr));
-      } else {
-        if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
-          inputs.push_back(*valuePtr);
-        }
+      if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
+        inputs.push_back(*valuePtr);
       }
     }

@@ -943,16 +911,11 @@ class FrontendGenImpl {
     };

     for (const auto &item : llvm::enumerate(node.input())) {
-      if (const onnx::TensorProto *tensorPtr =
-              initializedTensors.GetByOnnxName(item.value())) {
-        inVals[item.index()] = ImportTensor(*tensorPtr);
+      if (const Value *valuePtr =
+              frontend_symbols_.GetByOnnxName(item.value())) {
+        inVals[item.index()] = *valuePtr;
       } else {
-        if (const Value *valuePtr =
-                frontend_symbols_.GetByOnnxName(item.value())) {
-          inVals[item.index()] = *valuePtr;
-        } else {
-          assert(false && "Unknown input");
-        }
+        assert(false && "Unknown input");
       }
     }

@@ -1006,11 +969,7 @@ class FrontendGenImpl {
     // Copy the provided inputs first.
     std::vector<Value> inputs;
     for (const auto &item : node.input()) {
-      if (const onnx::TensorProto *tensorPtr =
-              initializedTensors.GetByOnnxName(item)) {
-        inputs.push_back(ImportTensor(*tensorPtr));
-      } else if (const Value *valuePtr =
-                     frontend_symbols_.GetByOnnxName(item)) {
+      if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
         inputs.push_back(*valuePtr);
       }
     }

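As a closing note on reason 3, the lookup pattern before and after this change, excerpted from the hunks above: the old path re-ran ImportTensor on every access to an initializer, materializing a fresh constant each time, while the new path reuses the single constant bound during graph import.

    // Before this commit: a constant was created per use of the initializer.
    if (const onnx::TensorProto *tensorPtr =
            initializedTensors.GetByOnnxName(item)) {
      inputs.push_back(ImportTensor(*tensorPtr));
    }

    // After this commit: one constant, created up front, looked up like any
    // other symbol.
    if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
      inputs.push_back(*valuePtr);
    }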