
Optimize onnxruntime::InferenceSession::Initialize with focus on GraphViewer. For large models the speedup of this function can be up to 3x. #19080

Closed

Conversation

mtavenrath
Contributor

@mtavenrath mtavenrath commented Jan 10, 2024

Description

  • In multiple hot-loop locations, copies of unordered_maps and other types were made where references would have been possible
  • In multiple locations, return by value was used where returning by const& would have been appropriate
  • Replaced std::unordered_map with std::vector for in_degree in KahnsTopologicalSort
  • Added a Node::isForwardNode() function to cache the existence of kBackwardNodeAttributeName
  • Inlined the functions of NodeConstIterator
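The in_degree change can be sketched as follows. This is a simplified stand-in for the ORT implementation (the adjacency layout and the NodeIndex alias here are assumptions), showing why a flat vector indexed by NodeIndex avoids per-node hashing:

```cpp
#include <cstddef>
#include <deque>
#include <vector>

using NodeIndex = std::size_t;

// graph[i] lists the successors of node i.
std::vector<NodeIndex> KahnsTopologicalSort(
    const std::vector<std::vector<NodeIndex>>& graph) {
  // O(1) array indexing instead of hashing each NodeIndex on every access.
  std::vector<std::size_t> in_degree(graph.size(), 0);
  for (const auto& successors : graph)
    for (NodeIndex succ : successors) ++in_degree[succ];

  std::deque<NodeIndex> to_visit;
  for (NodeIndex i = 0; i < graph.size(); ++i)
    if (in_degree[i] == 0) to_visit.push_back(i);

  std::vector<NodeIndex> order;
  while (!to_visit.empty()) {
    NodeIndex current = to_visit.front();
    to_visit.pop_front();
    order.push_back(current);
    for (NodeIndex succ : graph[current])
      if (--in_degree[succ] == 0) to_visit.push_back(succ);
  }
  return order;
}
```

Since NodeIndex values are dense small integers bounded by MaxNodeIndex(), in_degree[i] is a single array access, whereas the unordered_map version hashed the index on every increment and decrement.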

Motivation and Context

Load time of the Hugging Face 7B llama files can be quite long (>45s). 50% of this time is I/O, while the other 50% is spent in TransformGraph. To decrease loading time, and thus startup time, it'd be good to optimize the non-I/O part as much as possible.

@mtavenrath
Contributor Author

@microsoft-github-policy-service agree company="NVIDIA"

Contributor Author

@mtavenrath mtavenrath left a comment

@gedoensmax Can you take a look at this change please?

@@ -878,6 +850,7 @@ void Node::Init(std::string_view name,
gsl::span<NodeArg* const> output_args,
const NodeAttributes* attributes,
std::string_view domain) {
isForwardNode_ = true;
Contributor Author

@mtavenrath mtavenrath Jan 10, 2024


Updating the forward-node flag on each change is somewhat ugly and risky. Performance-wise it should be unproblematic, since the string compare is quite efficient: it does a length check first before actually comparing bytes.

The reason for caching isForwardNode_ is that each check for kBackwardNodeAttributeName within PriorityNodeCompare computes the hash of the string, which is the costly part.

An alternate solution would change the container and precompute the hash of kBackwardNodeAttributeName, incorporating it into the key. The downside of this solution is that it'd only work for cases where collisions are handled by lists instead of multiple iterations of hash keys.
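A minimal sketch of that alternate approach: a key type carrying a precomputed hash, so lookups skip rehashing the string. All names here (PrehashedKey, PrehashedKeyHash) are hypothetical, not ORT types:

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical key type: the hash is computed once, at construction.
struct PrehashedKey {
  std::string name;
  std::size_t hash;
  explicit PrehashedKey(std::string n)
      : name(std::move(n)), hash(std::hash<std::string>{}(name)) {}
  bool operator==(const PrehashedKey& other) const { return name == other.name; }
};

// Hasher that simply returns the stored hash instead of rehashing the string.
struct PrehashedKeyHash {
  std::size_t operator()(const PrehashedKey& k) const noexcept { return k.hash; }
};

using AttributeMap = std::unordered_map<PrehashedKey, int, PrehashedKeyHash>;

// The hot key's hash is paid exactly once, at static initialization.
static const PrehashedKey kBackwardKey{"__backwardpass"};
```

Note that on a hash match the container still falls back to the full string comparison via operator==, which is where the collision-handling caveat above comes in.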

-    int64_t n1_is_forward = static_cast<int64_t>(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) ||
+    auto const& n1_attrs = n1->GetAttributes();
+    auto const& n2_attrs = n2->GetAttributes();
+    int64_t n1_is_forward = static_cast<int64_t>(n1->isForwardNode()) ||
Contributor Author

This change shaves nearly 2s off the original 20s of TransformGraph and will be maximally efficient when loading models for inference.

For training models the hash computation is still done once. Before, the hash was potentially computed twice: once for find and once for at. By fetching the iterator from std::find and using it later to get the value, the number of hash computations could be halved (saving ~1s instead of 2s).

If it's known that the model is used for inference only, it'd be great if PriorityNodeCompare could skip this test altogether. This could be achieved most efficiently by making it a template class and specializing it for inference / training.
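Both ideas can be sketched like this — a simplified stand-in for PriorityNodeCompare, not the ORT code: a single find() whose iterator is reused, and a compile-time specialization so inference builds skip the lookup entirely:

```cpp
#include <string>
#include <unordered_map>

using NodeAttributes = std::unordered_map<std::string, int>;
constexpr const char* kBackwardNodeAttributeName = "__backwardpass";

// (a) One hash computation instead of find() followed by at():
inline bool IsBackwardNode(const NodeAttributes& attrs) {
  auto it = attrs.find(kBackwardNodeAttributeName);  // single lookup
  return it != attrs.end() && it->second != 0;       // reuse the iterator
}

// (b) Compile-time specialization: inference builds instantiate
// PriorityNodeCompare<false> and never touch the attribute map.
template <bool IsTrainingBuild>
struct PriorityNodeCompare {
  bool operator()(const NodeAttributes& n1, const NodeAttributes& n2) const {
    if constexpr (IsTrainingBuild) {
      // Forward nodes sort before backward nodes.
      return !IsBackwardNode(n1) && IsBackwardNode(n2);
    } else {
      return false;  // no backward nodes can exist in inference models
    }
  }
};
```

The comparator body here is deliberately reduced to the backward-pass test only; the real class also compares priorities.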

@gedoensmax
Contributor

@yufenglee can you help find a reviewer for this?

include/onnxruntime/core/graph/graph.h (resolved)
Comment on lines 401 to 402
/** @returns true if the Node is a forward node, false otherwise. **/
bool isForwardNode() const noexcept { return isForwardNode_; }
Contributor

Can the default compiled binaries even train a model, or is a specialized compilation needed?

Contributor

@skottmckay skottmckay Jan 18, 2024

Is this training specific and the additional code to track whether it's a forward node or not could be inside #if defined(ENABLE_TRAINING)?

nit: IsForwardNode/is_forward_node_ would be consistent with the coding standards.

Contributor Author

This is for the PriorityNodeCompare class in graph_viewer.cc.

Reading graph_viewer and this comment

    auto const& n1_attrs = n1->GetAttributes();
    auto const& n2_attrs = n2->GetAttributes();
    int64_t n1_is_forward = static_cast<int64_t>(n1->isForwardNode()) ||

makes me curious: is this purely for visualization? The graph_viewer is used by a lot of operations in TransformGraph. Does 'will be output first' mean for printing, or is it really required for the graph transformation?

Is the information about training available at runtime as well? In that case we could pass the information to the graph viewer and skip this expensive portion of code.

Researching further into avoiding hash computation, it would probably be best to have a special key type which precomputes the hash, e.g.:

https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1661r1.html
https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p0920r2.html

@gedoensmax
Contributor

@skottmckay as you have a lot of experience with the graph optimizer, would you have time to check this PR out, or could you help find someone to review?

@@ -898,7 +871,12 @@ void Node::Init(std::string_view name,
if (attributes) {
attributes_ = *attributes;

isForwardNode_ = true;
Contributor Author

@mtavenrath mtavenrath Jan 18, 2024

I am not particularly happy tracking kBackwardNodeAttributeName in multiple places. Is there a chance to have a common pair of functions to add/remove nodes, to minimize breaking changes in the future?


include/onnxruntime/core/graph/graph.h (outdated, resolved)
Comment on lines 978 to 982
size_t erased = attributes_.erase(attr_name);
if (erased && attr_name == kBackwardNodeAttributeName) {
isForwardNode_ = true;
}
return erased > 0;
Contributor

@skottmckay skottmckay Jan 18, 2024

I'm not familiar with the training code and the usage of this attribute. When is the attribute added, and once added is it actually ever removed?

Contributor Author

I'm not familiar with training either. IMHO it doesn't matter whether this attribute gets removed in normal workflows. What matters is that it can be removed; not checking for removal would change the behaviour and potentially even introduce bugs.

Contributor

In this case it doesn't seem like a standard attribute, so I would like to understand how it's used before adding more stuff to the production code to handle it theoretically changing. There's a binary size and maintenance cost, and if it's a purely internal value with specific usage it may be better to validate the usage remains as expected via unit tests.

Based on this GitHub code search, it has a very internal name, only seems to be set in a training optimizer, and only seems to be read in the graph_viewer code.

Unless there's some external usage of this magic value outside of ORT, it seems like it would be simpler for a Node to have a bool member that is directly set by the training optimizer instead of the indirect costly usage of a specially named attribute.

@askhade do you know if this special value is used outside of ORT and must be set in the Node attributes?

Contributor

I checked the code and IMO it is OK to make this a bool member of the Node class instead of making it a named attribute. There is widespread usage of this in the code, so the change should not be very cumbersome. @mtavenrath let me know if you have any questions regarding this. Tagging @pengwa to validate this.

@mtavenrath what timeline are you targeting for this change? We may need a couple of days to get this reviewed from Peng.

Contributor

One note: when finding usages to update, you need to search for both the kBackwardNodeAttributeName constant and the string "__backwardpass". Ideally we can make all places use the constant.

https://github.com/search?q=repo%3Amicrosoft%2Fonnxruntime%20kBackwardNodeAttributeName&type=code
https://github.com/search?q=repo%3Amicrosoft%2Fonnxruntime%20__backwardpass&type=code

Contributor Author

@askhade I'm happy with any timeline, as long as it is not a month+ out.

Contributor

@pengwa pengwa Feb 18, 2024

> I checked the code and IMO it is OK to make this a bool member of the Node class instead of making it a named attribute. There is widespread usage of this in the code, so the change should not be very cumbersome. @mtavenrath let me know if you have any questions regarding this. Tagging @pengwa to validate this.
>
> @mtavenrath what timeline are you targeting for this change? We may need a couple of days to get this reviewed from Peng.

FYI @askhade @skottmckay Yes, it is a purely internally used attribute, and

  1. it was first introduced as a backward-data-range-specific perf improvement in https://github.com/microsoft/onnxruntime/blame/dfeda9019cfed2d6df5bcacc54269c7de481bdee/onnxruntime/core/providers/rocm/rocm_kernel.h#L29.

  2. a second usage pattern is: in priority-based topological ordering, we consider a backward-tagged node to have lower priority than a forward node, in the training code path

    Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified) {

Both usages have similar logic when setting the backward pass to true, e.g.:

  for (auto node_index : node_topology_list) {
    auto& node = *graph.GetNode(node_index);

    if (node.OpType() == "YieldOp") {
      is_backward_pass = true;
    }

in d5d6924#diff-8d8d103ec215ba8edb8ab23e876080adfd60f6f377084ffeca041c8b4f189a2cR13

and

// Find the YieldOp node.
  Node* yield_op_node = nullptr;
  for (auto& node : graph.Nodes()) {
    if (node.OpType() == "YieldOp") {
      yield_op_node = &node;
      break;
    }
  }

in

Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified) {

Plus "YieldOp" is used only in ORTModule (e.g. --enable_training) builds, so I think it is possible to restrict the forward getter/setter to the ENABLE_TRAINING macro. This may need some change in onnxruntime/core/session/provider_bridge_ort.cc to wrap the new bool property. (Not sure whether you tried building/running your local code with ROCm, but I feel the change should be needed to make the ROCm EP code work.)
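Restricting the new member to training builds, as suggested, could look roughly like this. The setter name and the exact placement of the guard are assumptions for illustration, and the macro is defined inline only to keep the sketch self-contained:

```cpp
// Defined here only so this standalone sketch compiles; in ORT the flag
// comes from the build system when training support is enabled.
#define ENABLE_TRAINING

class Node {
 public:
#if defined(ENABLE_TRAINING)
  // Only training builds pay the binary-size and maintenance cost.
  bool isForwardNode() const noexcept { return is_forward_node_; }
  void SetIsForwardNode(bool value) noexcept { is_forward_node_ = value; }

 private:
  // Hypothetical member, set directly by the training optimizer instead of
  // a specially named attribute.
  bool is_forward_node_ = true;
#endif
};
```

In an inference-only build the member and its accessors would simply not exist, so no call site outside ENABLE_TRAINING could depend on them.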

Contributor

@pengwa pengwa Feb 18, 2024

One thing I have noticed for a while: both training and inference builds can use priority-based order. It is indeed needed for some training features, but I don't know how much value it brings for model inferencing.

If most inference users don't have such a need, maybe we can load nodes_in_topological_order_with_priority_ lazily, i.e. only initialize it the first time a user needs it via GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED).

@skottmckay @askhade
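The lazy-initialization idea could be sketched as below; GraphViewerSketch and the placeholder node orders are assumptions for illustration, not the actual GraphViewer:

```cpp
#include <cstddef>
#include <optional>
#include <vector>

using NodeIndex = std::size_t;
enum class ExecutionOrder { DEFAULT, PRIORITY_BASED };

class GraphViewerSketch {
 public:
  const std::vector<NodeIndex>& GetNodesInTopologicalOrder(ExecutionOrder order) const {
    if (order == ExecutionOrder::PRIORITY_BASED) {
      if (!priority_order_) {
        // The expensive priority-based sort runs at most once, on demand.
        priority_order_ = ComputePriorityBasedOrder();
      }
      return *priority_order_;
    }
    return default_order_;
  }

 private:
  std::vector<NodeIndex> ComputePriorityBasedOrder() const {
    // Placeholder for the real Kahn's sort with the priority comparator.
    return {2, 0, 1};
  }
  std::vector<NodeIndex> default_order_{0, 1, 2};
  mutable std::optional<std::vector<NodeIndex>> priority_order_;  // lazy cache
};
```

Inference sessions that never request PRIORITY_BASED would then skip the comparator (and its attribute lookups) entirely. If GraphViewer is shared across threads, the lazy initialization would also need synchronization, which this sketch omits.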

@mtavenrath
Contributor Author

@jeffbloo @PatriceVignola Hi, can you take a look at this PR for further discussion?

@@ -878,6 +850,7 @@ void Node::Init(std::string_view name,
gsl::span<NodeArg* const> output_args,
const NodeAttributes* attributes,
std::string_view domain) {
isForwardNode_ = true;
Member

isForwardNode_ = true;

Seems to be a duplicate of the same assignment below.

@@ -1821,13 +1811,13 @@ void Graph::ReverseDFSFrom(gsl::span<const Node* const> from,
#if !defined(ORT_MINIMAL_BUILD)
void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
const std::function<bool(const Node*, const Node*)>& comp) const {
std::unordered_map<NodeIndex, size_t> in_degree;
std::vector<size_t> in_degree(MaxNodeIndex(), 0);
Member

std::vector<size_t>

Prefer InlinedVector

@mtavenrath
Contributor Author

I have updated the PR to replace the "__backward" Node attribute with Node::isForwardNode(). I haven't been able to build/test training due to build errors on Windows; see issue #19269.

@mtavenrath
Contributor Author

There might be a hole in the replacement of the attribute by a property. Should it be possible to save/restore an ONNX file with the newly added forward attribute? If so, this change is invalid.

I had a separate, less intrusive version of this change which saved 50% of the unordered_map lookups. Compared to the most optimal version it adds ~1s to the 6.5s load time, yet is still ~1s faster than having the two lookups.

@askhade If ONNX Runtime is compiled for inference only, could we just put an ifdef around the expensive map lookup so that inference startups are quick?

include/onnxruntime/core/graph/graph.h (outdated, resolved)
onnxruntime/core/graph/graph.cc (outdated, resolved)
onnxruntime/core/optimizer/rocm_blas_alt_impl.cc (outdated, resolved)
onnxruntime/core/providers/rocm/rocm_kernel.h (outdated, resolved)
onnxruntime/core/graph/graph_viewer.cc (outdated, resolved)
@skottmckay
Contributor

> There might be a hole in the replacement of the attribute by a property. Should it be possible to save/restore an ONNX file with the newly added forward attribute? If so this change is invalid.

I asked Ashwini at the time and she said there was no save/restore involving the attribute, so I believe it's fine.

> I had a separate version of this change where I was able to save 50% of the unordered_map lookups, which was less intrusive. In comparison to the most optimal version this adds ~1s to the 6.5s load time, yet is still ~1s faster than having the two lookups.

> @askhade If ONNX is compiled for inference only, could we just put an ifdef around the expensive map lookup so that inference startups are quick?

Looks like that is being done in #19475. Makes sense to me to exclude all logic looking for backward nodes, given they won't exist in a build that does not include the training code.


onnxruntime/core/graph/graph.cc (outdated, resolved)
@@ -394,6 +398,12 @@ class Node {
/** Gets the Node's attributes. */
const NodeAttributes& GetAttributes() const noexcept { return attributes_; }

/** @returns true if the Node is a forward node, false otherwise. **/
bool isForwardNode() const noexcept { return is_forward_node_; }
Contributor Author

I used node.isForwardNode() for now, which duplicates 'node' in the code. Would it be better to replace isForwardNode/is_forward_node_ with isForward/is_forward_?

isBackwardPass might be an even better name.

Contributor

@skottmckay skottmckay Feb 16, 2024

I think IsBackwardPass would be clearer ('backward pass' is more meaningful than 'forward'), and I agree that having 'Node' on the end of a method name of the Node class is unnecessary.

Maybe 'ForBackwardPass' would be slightly clearer than 'Is' as well.

@mtavenrath
Contributor Author

With the changes in #19475 most of the changes in this PR are now obsolete, and it even improves on the PriorityQueue, which was on my TODO list. I'll profile my networks again to see if any of these changes are still required.

@mtavenrath
Contributor Author

Mostly superseded by #19475. Closing.

@mtavenrath mtavenrath closed this Feb 23, 2024