Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove initializedTensors map from ONNX parser #1739

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 17 additions & 58 deletions src/Builder/FrontendDialectTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ namespace detail {

using ValueSymbolMapping = SymbolMapping<Value>;
using SymbolToOnnxTypeMapping = SymbolMapping<onnx::TypeProto>;
using InitializedTensorMapping = SymbolMapping<onnx::TensorProto>;

class FrontendGenImpl {
public:
Expand Down Expand Up @@ -121,11 +120,6 @@ class FrontendGenImpl {
// onnxop: the top version in third_part/onnx
std::map<std::string, int> op_dialect_top_version_map_;

/*!
* The list of tensors initialized by the ONNX model.
*/
InitializedTensorMapping initializedTensors;

// mapping between string name and symbol
ValueSymbolMapping frontend_symbols_;

Expand Down Expand Up @@ -330,18 +324,6 @@ class FrontendGenImpl {
return llvm::Optional<Type>();
}

/*!
* Import a input tensor symbol by recording a new entry in frontend_symbols_
* recording the mapping between legalized onnx tensor name and Value
* for further lookup in computation node importing.
* @param input onnx input tensor ValueInfoProto.
* @param symbol mlir input argument.
*/
void ImportInputTensorSymbol(
const onnx::ValueInfoProto &input, Value symbol) {
BindOnnxName(input.name(), symbol);
}

NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
onnx::AttributeProto attr) {
Attribute mlirAttr;
Expand Down Expand Up @@ -424,14 +406,14 @@ class FrontendGenImpl {
FunctionType importGraph(const onnx::GraphProto &graph, Region &region,
Operation *op, bool useStdReturn) {
frontend_symbols_.pushScope(graph.name());
initializedTensors.pushScope(graph.name());
onnx_type_map.pushScope(graph.name());
Block *entryBlock = &region.back();

// Maintain a mapping between the parameter and its initializer.
std::unordered_set<std::string> initializerNames;
for (const auto &initializer : graph.initializer()) {
const auto &initializerName = initializer.name();
initializedTensors.AddMapping(initializerName, initializer);
BindOnnxName(initializer.name(), ImportTensor(initializer));
initializerNames.insert(initializer.name());
}

// create a function for the graph
Expand All @@ -447,7 +429,7 @@ class FrontendGenImpl {
int numInputs = 0;
for (const auto &input : graph.input()) {
AddValueInfo(input);
if (!initializedTensors.ContainsKey(input.name())) {
if (initializerNames.count(input.name()) == 0) {
inputNames.push_back(input.name());
auto argTy = ImportType(input.type());
auto shapedTy = argTy.dyn_cast<RankedTensorType>();
Expand Down Expand Up @@ -507,10 +489,10 @@ class FrontendGenImpl {
// Counter of un-initialized tensors. This counter is used to index the
// entry block arguments.
int entryBlockArgIdx = 0;
for (const onnx::ValueInfoProto &inputProto : graph.input()) {
if (!initializedTensors.ContainsKey(inputProto.name())) {
ImportInputTensorSymbol(
inputProto, entryBlock->getArguments()[entryBlockArgIdx]);
for (const auto &input : graph.input()) {
if (initializerNames.count(input.name()) == 0) {
BindOnnxName(
input.name(), entryBlock->getArguments()[entryBlockArgIdx]);
entryBlockArgIdx++;
}
}
Expand All @@ -537,7 +519,6 @@ class FrontendGenImpl {
op->setAttr("output_names", builder_.getStrArrayAttr(outputNames));

frontend_symbols_.popScope(graph.name());
initializedTensors.popScope(graph.name());
onnx_type_map.popScope(graph.name());
return builder_.getFunctionType(argTypes, retTys);
}
Expand Down Expand Up @@ -723,11 +704,7 @@ class FrontendGenImpl {
if (item.empty()) {
inputs.emplace_back(none());
} else {
if (const onnx::TensorProto *tensorPtr =
initializedTensors.GetByOnnxName(item)) {
inputs.push_back(ImportTensor(*tensorPtr));
} else if (const Value *valuePtr =
frontend_symbols_.GetByOnnxName(item)) {
if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
inputs.push_back(*valuePtr);
}
}
Expand Down Expand Up @@ -797,11 +774,7 @@ class FrontendGenImpl {
// Optional inputs using empty string will be imported as NoneType.
inputs.emplace_back(none());
} else {
if (const onnx::TensorProto *tensorPtr =
initializedTensors.GetByOnnxName(item)) {
inputs.push_back(ImportTensor(*tensorPtr));
} else if (const Value *valuePtr =
frontend_symbols_.GetByOnnxName(item)) {
if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
inputs.push_back(*valuePtr);
}
}
Expand Down Expand Up @@ -852,13 +825,8 @@ class FrontendGenImpl {
// Copy the provided inputs first
std::vector<Value> inputs;
for (const auto &item : node.input()) {
if (const onnx::TensorProto *tensorPtr =
initializedTensors.GetByOnnxName(item)) {
inputs.push_back(ImportTensor(*tensorPtr));
} else {
if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
inputs.push_back(*valuePtr);
}
if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
inputs.push_back(*valuePtr);
}
}

Expand Down Expand Up @@ -943,16 +911,11 @@ class FrontendGenImpl {
};

for (const auto &item : llvm::enumerate(node.input())) {
if (const onnx::TensorProto *tensorPtr =
initializedTensors.GetByOnnxName(item.value())) {
inVals[item.index()] = ImportTensor(*tensorPtr);
if (const Value *valuePtr =
frontend_symbols_.GetByOnnxName(item.value())) {
inVals[item.index()] = *valuePtr;
} else {
if (const Value *valuePtr =
frontend_symbols_.GetByOnnxName(item.value())) {
inVals[item.index()] = *valuePtr;
} else {
assert(false && "Unknown input");
}
assert(false && "Unknown input");
}
}

Expand Down Expand Up @@ -1006,11 +969,7 @@ class FrontendGenImpl {
// Copy the provided inputs first.
std::vector<Value> inputs;
for (const auto &item : node.input()) {
if (const onnx::TensorProto *tensorPtr =
initializedTensors.GetByOnnxName(item)) {
inputs.push_back(ImportTensor(*tensorPtr));
} else if (const Value *valuePtr =
frontend_symbols_.GetByOnnxName(item)) {
if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(item)) {
inputs.push_back(*valuePtr);
}
}
Expand Down