diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h
index 9b26ba914c7dd..7e59aad80cc47 100644
--- a/include/onnxruntime/core/graph/constants.h
+++ b/include/onnxruntime/core/graph/constants.h
@@ -55,7 +55,4 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
 constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path";
 constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point";
 
-// For Priority based graph topology sorting.
-constexpr const char* kBackwardNodeAttributeName = "__backwardpass";
-
 }  // namespace onnxruntime
diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index 22827d43b200f..c5fb214bb4be9 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -294,17 +294,21 @@ class Node {
   Class to provide const access to Node instances iterated via an EdgeConstIterator. */
   class NodeConstIterator {
    public:
-    NodeConstIterator(EdgeConstIterator p_iter);
+    NodeConstIterator(EdgeConstIterator p_iter) { m_iter = p_iter; }
 
-    bool operator==(const NodeConstIterator& p_other) const;
+    bool operator==(const NodeConstIterator& p_other) const {
+      return m_iter == p_other.m_iter;
+    }
 
-    bool operator!=(const NodeConstIterator& p_other) const;
+    bool operator!=(const NodeConstIterator& p_other) const {
+      return m_iter != p_other.m_iter;
+    }
 
-    void operator++();
-    void operator--();
+    void operator++() { ++m_iter; }
+    void operator--() { --m_iter; }
 
-    const Node& operator*() const;
-    const Node* operator->() const;
+    const Node& operator*() const { return (*m_iter).GetNode(); }
+    const Node* operator->() const { return &(operator*()); }
 
    private:
     EdgeConstIterator m_iter;
@@ -394,6 +398,12 @@ class Node {
   /** Gets the Node's attributes. */
   const NodeAttributes& GetAttributes() const noexcept { return attributes_; }
 
+  /** @returns true if the Node is a forward node (inference or the training forward pass), false if it belongs to the training backward pass. */
+  bool IsForwardNode() const noexcept { return is_forward_node_; }
+
+  /** Sets whether the Node is a forward node. */
+  void SetForwardNode(bool is_forward_node) noexcept { is_forward_node_ = is_forward_node; }
+
 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
   /** Remove the specified attribute from this Node */
   bool ClearAttribute(const std::string& attr_name);
@@ -626,6 +636,9 @@ class Node {
   // Execution priority, lower value for higher priority
   int priority_ = 0;
 
+  // True if this node is a forward node, otherwise it is a backward node.
+  bool is_forward_node_ = true;
+
   // set from op_->SinceVersion() or via deserialization when OpSchema is not available
   int since_version_ = -1;
 
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index f71b7ecebcf1a..daf6593598137 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -528,34 +528,6 @@ Node::EdgeEnd::EdgeEnd(const Node& node) noexcept
     : EdgeEnd(node, INT_MAX, INT_MAX) {
 }
 
-Node::NodeConstIterator::NodeConstIterator(EdgeConstIterator p_iter) {
-  m_iter = p_iter;
-}
-
-bool Node::NodeConstIterator::operator==(const NodeConstIterator& p_other) const {
-  return m_iter == p_other.m_iter;
-}
-
-bool Node::NodeConstIterator::operator!=(const NodeConstIterator& p_other) const {
-  return m_iter != p_other.m_iter;
-}
-
-void Node::NodeConstIterator::operator++() {
-  ++m_iter;
-}
-
-void Node::NodeConstIterator::operator--() {
-  --m_iter;
-}
-
-const Node& Node::NodeConstIterator::operator*() const {
-  return (*m_iter).GetNode();
-}
-
-const Node* Node::NodeConstIterator::operator->() const {
-  return &(operator*());
-}
-
 void Node::SetPriority(int priority) noexcept {
   priority_ = priority;
 }
@@ -878,6 +850,7 @@ void Node::Init(std::string_view name,
                 gsl::span<NodeArg* const> output_args,
                 const NodeAttributes* attributes,
                 std::string_view domain) {
+  is_forward_node_ = true;
   name_ = name;
   op_type_ = op_type;
   description_ = description;
@@ -898,6 +871,7 @@ void Node::Init(std::string_view name,
 
   if (attributes) {
     attributes_ = *attributes;
+    is_forward_node_ = true;
     for (auto& name_to_attr : attributes_) {
       if (utils::HasGraph(name_to_attr.second)) {
 #if !defined(ORT_MINIMAL_BUILD)
@@ -1821,13 +1795,13 @@ void Graph::ReverseDFSFrom(gsl::span<const Node* const> from,
 #if !defined(ORT_MINIMAL_BUILD)
 void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
                                  const std::function<bool(const Node*, const Node*)>& comp) const {
-  std::unordered_map<NodeIndex, size_t> in_degree;
+  InlinedVector<size_t> in_degree(MaxNodeIndex(), 0);
   std::priority_queue<const Node*, std::vector<const Node*>, decltype(comp)> to_visit(comp);
-  std::vector<NodeIndex> topo_order;
+  InlinedVector<NodeIndex> topo_order;
 
   for (auto& node : Nodes()) {
     size_t input_edge_count = node.GetInputEdgesCount();
-    in_degree.insert({node.Index(), input_edge_count});
+    in_degree[node.Index()] = input_edge_count;
     if (input_edge_count == 0) {
       to_visit.push(&node);
     }
   }
@@ -2044,7 +2018,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
     }
   }
 
-  std::vector<TypeProto> InferredOutputTypes() const { return node_output_types_; }
+  const std::vector<TypeProto>& InferredOutputTypes() const noexcept { return node_output_types_; }
 
   const AttributeProto* getAttribute(const std::string& name) const override {
     auto& attribute_value_map = node_.GetAttributes();
@@ -2240,7 +2214,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
     // Number of inputs corresponding to the i-th argument.
     const int arg_count = node.InputArgCount()[i];
     // The i-th formal parameter definition.
-    auto op_formal_parameter = op.inputs()[i];
+    const auto& op_formal_parameter = op.inputs()[i];
 
     // Check all actual parameters (corresponding to the k-th input)
     // match the formal parameter definition (i-th argument).
@@ -2345,7 +2319,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
 
     const int num_formal_params = gsl::narrow_cast<int>(op.outputs().size());
     auto operand_index = std::min(i, num_formal_params - 1);
-    auto op_formal_parameter = op.outputs().at(operand_index);
+    const auto& op_formal_parameter = op.outputs().at(operand_index);
 
     const TypeProto& onnx_inferred_type = onnx_inferred_types[i];
     DataType existing_type = output_def->Type();
diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc
index cf78040ea5ac6..b21b6fb50fa4a 100644
--- a/onnxruntime/core/graph/graph_viewer.cc
+++ b/onnxruntime/core/graph/graph_viewer.cc
@@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const {
 struct PriorityNodeCompare {
   inline bool IsHighPri(const Node* n) const {
     // local statics so we can compare std::strings in the checks
-    static const std::string shape_op("Shape");
-    static const std::string size_op("Size");
+    static constexpr std::string_view shape_op("Shape");
+    static constexpr std::string_view size_op("Size");
 
     const auto& op_type = n->OpType();
     return op_type == shape_op || op_type == size_op;
@@ -36,12 +36,8 @@ struct PriorityNodeCompare {
     }
 
     // nodes of forward pass will be output first
-    auto n1_attrs = n1->GetAttributes();
-    auto n2_attrs = n2->GetAttributes();
-    int64_t n1_is_forward = static_cast<int64_t>(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) ||
-                            (n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
-    int64_t n2_is_forward = static_cast<int64_t>(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) ||
-                            (n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
+    int64_t n1_is_forward = n1->IsForwardNode();
+    int64_t n2_is_forward = n2->IsForwardNode();
     if (n1_is_forward != n2_is_forward) {
       return n2_is_forward > n1_is_forward;
     }
diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc
index b04d794cc9469..75f06960aee48 100644
--- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc
@@ -255,11 +255,7 @@ Status ProcessNode(
   matmul_scale_node.SetExecutionProviderType(node.GetExecutionProviderType());
 
 #ifdef USE_ROCM
-  // forward the __backwardpass, if present
-  auto& attrs = node.GetAttributes();
-  if (attrs.count("__backwardpass")) {
-    matmul_scale_node.AddAttribute("__backwardpass", static_cast<int64_t>(attrs.at("__backwardpass").i()));
-  }
+  matmul_scale_node.SetForwardNode(node.IsForwardNode());
 #endif
 
   {
diff --git a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc
index 789466778edc6..a9c66d061619a 100644
--- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc
@@ -407,10 +407,7 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
   matmul_node.SetExecutionProviderType(node.GetExecutionProviderType());
 #ifdef USE_ROCM
-  // forward the __backwardpass, if present
-  auto& attrs = node.GetAttributes();
-  if (attrs.count("__backwardpass")) {
-    matmul_node.AddAttribute("__backwardpass", static_cast<int64_t>(attrs.at("__backwardpass").i()));
-  }
+  // propagate the forward/backward classification to the fused node
+  matmul_node.SetForwardNode(node.IsForwardNode());
 #endif
 
   graph_utils::FinalizeNodeFusion(graph, matmul_node, node);
diff --git a/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc b/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc
index decb25f565efe..44c17ebe57e91 100644
--- a/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc
+++ b/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc
@@ -26,7 +26,7 @@ Status RocmBlasAltImpl::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
 
     if (is_backward_pass) {
-      node.AddAttribute(std::string("__backwardpass"), static_cast<int64_t>(1));
+      node.SetForwardNode(false);
       modified = true;
     }
   }
diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h
index c0b7d4722d3e4..a9a91aa96bae8 100644
--- a/onnxruntime/core/providers/rocm/rocm_kernel.h
+++ b/onnxruntime/core/providers/rocm/rocm_kernel.h
@@ -25,7 +25,7 @@ class RocmKernel : public OpKernel {
 
   Status Compute(OpKernelContext* p_op_kernel_context) const override {
     Status s;
-    auto is_backward_pass = Info().GetAttrOrDefault<int64_t>("__backwardpass", 0);
+    auto is_backward_pass = !Node().IsForwardNode();
     if (is_backward_pass) {
       BackwardPassGuard guard;
       s = ComputeInternal(p_op_kernel_context);
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc
index 9b77832abb6f1..1e533d1e7e11e 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc
@@ -197,16 +197,13 @@ Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified) {
   // Set the attribute to true for all backward nodes.
   for (auto& node : graph.Nodes()) {
     if (std::find(fw_nodes.begin(), fw_nodes.end(), &node) == fw_nodes.end()) {
-      auto& attrs = node.GetAttributes();
-      if (attrs.count(kBackwardNodeAttributeName)) {
-        continue;
+      if (node.IsForwardNode()) {
+        node.SetForwardNode(false);
+        modified = true;
       }
-      node.AddAttribute(kBackwardNodeAttributeName, static_cast<int64_t>(1));
-      modified = true;
     } else {
-      auto& attrs = node.GetAttributes();
-      if (attrs.count(kBackwardNodeAttributeName)) {
-        node.ClearAttribute(kBackwardNodeAttributeName);
+      if (!node.IsForwardNode()) {
+        node.SetForwardNode(true);
         modified = true;
       }
     }