From fa50316f6ceebca5fd87552b7584d3322ce176a8 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 29 Feb 2024 23:43:10 -0800 Subject: [PATCH] Revert "Optimize KahnsTopologicalSort and PriorityNodeCompare (#19475)" This reverts commit ef0b71308c0e2395d3ea63e627515ff8e624ad45. --- onnxruntime/core/graph/graph.cc | 37 ++-------- onnxruntime/core/graph/graph_viewer.cc | 18 ++--- .../core/optimizer/noop_elimination.cc | 73 ++++++++----------- .../ort_optimizer_api_impl.cc | 2 +- 4 files changed, 45 insertions(+), 85 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 305122c56b865..902839bee04ba 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,36 +1818,16 @@ void Graph::ReverseDFSFrom(gsl::span from, } } -template -struct VisitorPriorityQueue { - using ComparatorType = std::function; - std::list list_; - const ComparatorType comparator_ = nullptr; - VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} - - void push(T node) { - list_.insert( - std::upper_bound(list_.begin(), list_.end(), node, comparator_), - node); - } - bool empty() { return list_.empty(); } - T top() { return list_.back(); } - void pop() { list_.pop_back(); } -}; - #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { - InlinedVector in_degree(MaxNodeIndex(), 0); - InlinedVector topo_order; - VisitorPriorityQueue to_visit(comp); - - auto number_of_nodes = NumberOfNodes(); - topo_order.reserve(number_of_nodes); + std::unordered_map in_degree; + std::priority_queue, decltype(comp)> to_visit(comp); + std::vector topo_order; for (auto& node : Nodes()) { size_t input_edge_count = node.GetInputEdgesCount(); - in_degree[node.Index()] = input_edge_count; + in_degree.insert({node.Index(), input_edge_count}); if (input_edge_count == 0) { to_visit.push(&node); } @@ -1864,17 +1844,16 @@ void Graph::KahnsTopologicalSort(const std::function& enter, } for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) { - auto& node_in_degree = in_degree[node_it->Index()]; - node_in_degree--; + in_degree[node_it->Index()]--; - if (node_in_degree == 0) { + if (in_degree[node_it->Index()] == 0) { to_visit.push(&*node_it); } } topo_order.push_back(current->Index()); } - if (number_of_nodes != static_cast(topo_order.size())) { + if (NumberOfNodes() != static_cast(topo_order.size())) { ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } } @@ -2864,7 +2843,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { const gsl::not_null tensor_added{graph_proto_->add_initializer()}; *(tensor_added) = tensor; - name_to_initial_tensor_.emplace(tensor.name(), tensor_added); + name_to_initial_tensor_[tensor.name()] = tensor_added; SetGraphResolveNeeded(); if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) { // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 119d420066a84..acf7b3a16541f 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const { struct PriorityNodeCompare { inline bool IsHighPri(const Node* n) const { // local statics so we can compare std::strings in the checks - static constexpr std::string_view shape_op("Shape"); - static constexpr std::string_view size_op("Size"); + static const std::string shape_op("Shape"); + static const std::string size_op("Size"); const auto& op_type = n->OpType(); return op_type == shape_op || op_type == size_op; @@ -26,20 +26,15 @@ struct PriorityNodeCompare { // If return true, n2 will be output first bool operator()(const Node* n1, const Node* n2) const { // nodes in global high priority list will be output first - const bool isN1HighPri = IsHighPri(n1); - const bool isN2HighPri = IsHighPri(n2); - if (isN1HighPri != isN2HighPri) { - return isN2HighPri; + if (IsHighPri(n1) != IsHighPri(n2)) { + return IsHighPri(n2); } // nodes with lower priority value will be output first - const auto n1_priority = n1->Priority(); - const auto n2_priority = n2->Priority(); - if (n1_priority != n2_priority) { - return n1_priority > n2_priority; + if (n1->Priority() != n2->Priority()) { + return n1->Priority() > n2->Priority(); } -#ifdef ENABLE_TRAINING // nodes of forward pass will be output first auto n1_attrs = n1->GetAttributes(); auto n2_attrs = n2->GetAttributes(); @@ -50,7 +45,6 @@ struct PriorityNodeCompare { if (n1_is_forward != n2_is_forward) { return n2_is_forward > n1_is_forward; } -#endif // otherwise, nodes with lower index will be output first return n1->Index() > n2->Index(); diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index bba39b698a27a..b3c2991d54b28 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -42,62 +42,49 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con // if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule, // but it won't happen if the case is accepted, thus reject it - const auto& dims = initializer->dims(); - auto initializer_rank = dims.size(); + auto initializer_rank = initializer->dims().size(); const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 1 : 0]->Shape(); if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) { return false; } - int64_t tensor_size = 1; - for (auto i : dims) { - tensor_size *= i; - } - - if (tensor_size > 1) { + int32_t data_type = initializer->data_type(); + Initializer add_init(*initializer, graph.ModelPath()); + if (add_init.size() > 1) { return false; } - // handle edge case where the total size of the initializer is 0 - if (tensor_size == 0) { + if (add_init.size() == 0) { return true; } - if (op_type == "Add" || - op_type == "Sub" || - op_type == "Mul" || - op_type == "Div") { - int32_t data_type = initializer->data_type(); - Initializer add_init(*initializer, graph.ModelPath()); - - float value = 0.0f; - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - value = *add_init.data(); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - value = math::halfToFloat(add_init.data()->val); - break; - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - value = static_cast(*add_init.data()); - break; - default: - return false; - } - - if (value != 0.0f && (op_type == "Add" || op_type == "Sub")) { + float value = 0.0f; + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + value = *add_init.data(); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + value = math::halfToFloat(add_init.data()->val); + break; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + value = static_cast(*add_init.data()); + break; + default: return false; - } + } - if (value != 1.0f && (op_type == "Mul" || op_type == "Div")) { - return false; - } + if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) { + return false; + } + + if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) { + return false; } // reject node output is graph output for now diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index c532f56b3d3d9..d9f08ffe1171e 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -115,7 +115,7 @@ class ApiGraph final : public api::GraphRef { const auto& graph_outputs = graph_.GetOutputs(); graph_outputs_.reserve(graph_outputs.size()); for (const auto* output : graph_outputs) { - graph_outputs_.emplace(output->Name()); + graph_outputs_.insert(output->Name()); } }