Skip to content

Commit

Permalink
Cleaned up commit of the InferTypes change. (onnx#1753)
Browse files Browse the repository at this point in the history
Signed-off-by: Brad Messer <[email protected]>

Co-authored-by: Soren Lassen <[email protected]>
  • Loading branch information
messerb5467 and sorenlassen authored Sep 30, 2022
1 parent 562adae commit d08c5ac
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions src/Builder/FrontendDialectTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -960,15 +960,15 @@ class FrontendGenImpl {

void InferTypes(const onnx::FunctionProto *func,
std::vector<onnx::TypeProto> &inputTypes) {
// types: Used for temporary copies of Types, freed at end of function.
std::vector<std::unique_ptr<onnx::TypeProto>> types;
std::unordered_map<std::string, onnx::TypeProto *> typeMap;
// Initialize types and values (if available) of function inputs:
const auto num_inputs =
std::min(func->input_size(), static_cast<int>(inputTypes.size()));
for (int i = 0; i < num_inputs; ++i) {
const std::string &input_name = func->input(i);
typeMap[input_name] = &inputTypes[i];
onnx_type_map.AddMapping(input_name, inputTypes[i]);
typeMap[input_name] = const_cast<onnx::TypeProto *>(
onnx_type_map.GetByOnnxName(input_name));
}

for (const onnx::NodeProto &n : func->node()) {
Expand All @@ -983,16 +983,12 @@ class FrontendGenImpl {

// Update types:
for (int i = 0; i < n.output_size(); ++i) {
std::unique_ptr<onnx::TypeProto> p =
std::make_unique<onnx::TypeProto>(*node_ctx.getOutputType(i));
typeMap[n.output(i)] = p.get();
types.push_back(std::move(p));
const std::string &output_name = n.output(i);
onnx_type_map.AddMapping(output_name, *node_ctx.getOutputType(i));
typeMap[output_name] = const_cast<onnx::TypeProto *>(
onnx_type_map.GetByOnnxName(output_name));
}
}

for (auto pair : typeMap) {
onnx_type_map.AddMapping(pair.first, *pair.second);
}
}

bool TryImportFunctionCallNode(const onnx::NodeProto &node) {
Expand Down

0 comments on commit d08c5ac

Please sign in to comment.