Optimize onnxruntime::InferenceSession::Initialize with focus on GraphViewer. For large models the speedup of this function can be up to 3x. #19080
Conversation
@microsoft-github-policy-service agree company="NVIDIA"
@gedoensmax Can you take a look at this change please?
onnxruntime/core/graph/graph.cc
Outdated
@@ -878,6 +850,7 @@ void Node::Init(std::string_view name,
                 gsl::span<NodeArg* const> output_args,
                 const NodeAttributes* attributes,
                 std::string_view domain) {
+  isForwardNode_ = true;
Updating the forward flag on each change is somewhat ugly and risky. Regarding performance it should be unproblematic, since the string compare is quite efficient: it does a length check first before actually comparing bytes.
The reason for caching isForwardNode_ is that each check for kBackwardNodeAttributeName within PriorityNodeCompare computes the hash of the string, which is the costly part.
An alternate solution would change the container, precompute the hash of kBackwardNodeAttributeName, and incorporate it into the key. The downside of this solution is that it would work only for cases where collisions are handled by lists instead of multiple iterations of hash keys.
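The precomputed-hash alternative described above could be sketched roughly like this (a minimal illustration assuming a std::unordered_map-style container; PrehashedKey, PrehashedKeyHash, AttrMap, and HasAttr are hypothetical names, not ORT's actual NodeAttributes API):

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>

// A key type that carries a precomputed hash, so repeated lookups of a
// well-known key such as kBackwardNodeAttributeName never rehash the string.
struct PrehashedKey {
  std::string value;
  std::size_t hash;
  explicit PrehashedKey(std::string v)
      : value(std::move(v)), hash(std::hash<std::string>{}(value)) {}
  bool operator==(const PrehashedKey& other) const { return value == other.value; }
};

struct PrehashedKeyHash {
  // Returns the stored hash; no per-lookup recomputation.
  std::size_t operator()(const PrehashedKey& k) const { return k.hash; }
};

using AttrMap = std::unordered_map<PrehashedKey, int, PrehashedKeyHash>;

bool HasAttr(const AttrMap& attrs, const PrehashedKey& key) {
  return attrs.find(key) != attrs.end();
}
```

The hash is paid once when the key object is built (e.g. for a static constant), after which every find() only compares bytes on collision.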
-  int64_t n1_is_forward = static_cast<int64_t>(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) ||
+  auto const& n1_attrs = n1->GetAttributes();
+  auto const& n2_attrs = n2->GetAttributes();
+  int64_t n1_is_forward = static_cast<int64_t>(n1->isForwardNode()) ||
This change shaves off nearly 2s from the original 20s of TransformGraph and will be maximally efficient when loading models for inference.
For training models the hash computation is still done once. Before, the hash was potentially computed twice: once for find and once for at. By fetching the iterator from std::find and using it later to get the value, the number of hash computations could be halved (saving ~1s instead of 2s).
If it's known that the model is used for inference only, it'd be great if PriorityNodeCompare could skip this test altogether. This could be achieved most efficiently by making this a template class specialized for inference/training.
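The iterator-reuse idea from the comment above could look like this (a sketch only; backward_flag and the plain map type are illustrative stand-ins, not the actual ORT NodeAttributes interface):

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Call find() once and reuse the resulting iterator, instead of a find()
// followed by at(), which would hash the key a second time.
int backward_flag(const std::unordered_map<std::string, int>& attrs,
                  const std::string& key) {
  auto it = attrs.find(key);                  // hash computed exactly once
  return it == attrs.end() ? 0 : it->second;  // value read via the iterator
}
```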
@yufenglee can you help to find a reviewer for this?
/** @returns true if the Node is a forward node, false otherwise. **/
bool isForwardNode() const noexcept { return isForwardNode_; }
Can the default compiled binaries even train a model, or is a specialized compilation needed?
Is this training specific and the additional code to track whether it's a forward node or not could be inside #if defined(ENABLE_TRAINING)?
nit: IsForwardNode/is_forward_node_ would be consistent with the coding standards.
This is for the PriorityNodeCompare class in graph_viewer.cc.
Reading graph_viewer and this comment
auto const& n1_attrs = n1->GetAttributes();
auto const& n2_attrs = n2->GetAttributes();
int64_t n1_is_forward = static_cast<int64_t>(n1->isForwardNode()) ||
makes me curious: is this purely for visualization? The graph_viewer is used by a lot of operations in TransformGraph. Does 'will be output first' mean for printing, or is it really required for the graph transformation?
Is the information about training available at runtime as well? In that case we could pass the information to the graph viewer and skip this expensive portion of code.
Researching further into avoiding hash computation, it would probably be best to have a special key type that precomputes the hash, e.g.
https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1661r1.html
https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p0920r2.html
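Until a prehashed-key facility along the lines of those papers exists in the standard library, one cheap mitigation is hoisting the key object out of the hot loop, so at least the std::string construction happens once (find() still hashes per call, which is exactly the cost discussed here). A small sketch with illustrative names, not ORT code:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// The key string is built a single time for the whole program.
static const std::string kBackwardNodeAttributeName = "__backwardpass";

// Counts nodes carrying the backward attribute; one lookup per node.
int count_backward(const std::vector<std::unordered_map<std::string, int>>& all_attrs) {
  int n = 0;
  for (const auto& attrs : all_attrs) {
    if (attrs.find(kBackwardNodeAttributeName) != attrs.end()) ++n;
  }
  return n;
}
```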
@skottmckay as you have a lot of experience with the graph optimizer, would you have time to review this PR, or could you help find someone to review it?
onnxruntime/core/graph/graph.cc
Outdated
@@ -898,7 +871,12 @@ void Node::Init(std::string_view name,
   if (attributes) {
     attributes_ = *attributes;

+    isForwardNode_ = true;
I am not particularly happy tracking kBackwardAttributeName in multiple places. Is there a chance to have a common pair of functions to add/remove nodes, to minimize breaking changes in the future?
onnxruntime/core/graph/graph.cc
Outdated
size_t erased = attributes_.erase(attr_name);
if (erased && attr_name == kBackwardNodeAttributeName) {
  isForwardNode_ = true;
}
return erased > 0;
I'm not familiar with the training code and the usage of this attribute. When is the attribute added, and once added is it actually ever removed?
I'm not familiar with training either. IMHO it doesn't matter whether this attribute gets removed in normal workflows. What matters is that it can be removed, and not checking for removal would change the behaviour and potentially even introduce bugs.
In this case it doesn't seem like a standard attribute, so I would like to understand how it's used before adding more stuff to the production code to handle it theoretically changing. There's a binary size and maintenance cost, and if it's a purely internal value with specific usage it may be better to validate the usage remains as expected via unit tests.
Based on this github code search it has a very internal name, only seems to be set in a training optimizer, and only seems to be read in the graph_viewer code.
Unless there's some external usage of this magic value outside of ORT, it seems like it would be simpler for a Node to have a bool member that is directly set by the training optimizer instead of the indirect costly usage of a specially named attribute.
@askhade do you know if this special value is used outside of ORT and must be set in the Node attributes?
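The "bool member set directly by the training optimizer" suggestion could be sketched like this (NodeSketch is a simplified hypothetical stand-in for onnxruntime::Node, not the real class):

```cpp
#include <cassert>

// A plain bool on the node, set directly by the training optimizer,
// instead of the magic "__backwardpass" attribute.
class NodeSketch {
 public:
  bool ForBackwardPass() const noexcept { return for_backward_pass_; }
  void SetForBackwardPass(bool v) noexcept { for_backward_pass_ = v; }

 private:
  bool for_backward_pass_ = false;  // cheap to read in PriorityNodeCompare
};
```

Reading the flag is then a single byte load, with no attribute-map lookup and no string hashing per comparison.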
I checked the code and IMO it is OK to make this a bool member of the Node class instead of making it a named attribute. There is widespread usage of this in the code so the change should not be very cumbersome. @mtavenrath let me know if you have any questions regarding this. Tagging @pengwa to validate this.
@mtavenrath what timeline are you targeting for this change? We may need a couple of days to get this reviewed from Peng.
One note - when finding usage to update you need to search for both the kBackwardNodeAttributeName constant as well as the string "__backwardpass". Ideally we can make all places use the constant.
https://github.com/search?q=repo%3Amicrosoft%2Fonnxruntime%20kBackwardNodeAttributeName&type=code
https://github.com/search?q=repo%3Amicrosoft%2Fonnxruntime%20__backwardpass&type=code
@askhade I'm happy with any timeline, as long as it's not a month+ out.
FYI @askhade @skottmckay Yes, it is a purely internally used attribute, and
- It was first introduced as a backward-data-range-specific perf improvement in https://github.com/microsoft/onnxruntime/blame/dfeda9019cfed2d6df5bcacc54269c7de481bdee/onnxruntime/core/providers/rocm/rocm_kernel.h#L29.
- A second usage pattern: in priority-based topo ordering we consider a backward-tagged node to have lower priority than a forward node, in the training code path onnxruntime/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc, line 172 in dfeda90 (Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified)).
Both usages have similar logic when setting the backward pass to true, e.g.:
for (auto node_index : node_topology_list) {
  auto& node = *graph.GetNode(node_index);
  if (node.OpType() == "YieldOp") {
    is_backward_pass = true;
  }
}
in d5d6924#diff-8d8d103ec215ba8edb8ab23e876080adfd60f6f377084ffeca041c8b4f189a2cR13
and
// Find the YieldOp node.
Node* yield_op_node = nullptr;
for (auto& node : graph.Nodes()) {
if (node.OpType() == "YieldOp") {
yield_op_node = &node;
break;
}
}
in onnxruntime/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc, line 172 in dfeda90 (Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified)).
Plus "YieldOp" is used in ORTModule (e.g. --enable_training) builds only. So I think it is possible to restrict the get/set forward accessors to the ENABLE_TRAINING macro. This may need some change in onnxruntime/core/session/provider_bridge_ort.cc to wrap the new bool property (not sure whether you tried building/running your local code with ROCM, but I feel that change should be needed to make the ROCM EP code work).
One thing I noticed a while ago: both training and inference builds can use priority-based order. It is indeed needed for some training features, while I don't know how much value it brings for model inferencing.
If most inference users don't have such a need, maybe we can load nodes_in_topological_order_with_priority_ lazily, i.e. only initialize it the first time a user needs it via GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED).
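The lazy-initialization idea could look roughly like this (GraphSketch and its members are illustrative stand-ins; the real member is nodes_in_topological_order_with_priority_ on onnxruntime::Graph):

```cpp
#include <cassert>
#include <vector>

// Compute the priority-based order only on first request instead of
// eagerly during every model load.
class GraphSketch {
 public:
  const std::vector<int>& GetPriorityOrder() {
    if (!priority_order_valid_) {      // first call pays the sorting cost
      priority_order_ = ComputePriorityOrder();
      priority_order_valid_ = true;
    }
    return priority_order_;            // subsequent calls return the cache
  }

 private:
  // Placeholder for the real priority-based topological sort.
  std::vector<int> ComputePriorityOrder() { return {0, 1, 2}; }
  std::vector<int> priority_order_;
  bool priority_order_valid_ = false;
};
```

Inference sessions that never request the priority order would then skip the cost entirely; a thread-safety guard (e.g. std::call_once) would be needed if the getter can race.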
@jeffbloo @PatriceVignola Hi, can you take a look at this PR for further discussion?
onnxruntime/core/graph/graph.cc
Outdated
@@ -1821,13 +1811,13 @@ void Graph::ReverseDFSFrom(gsl::span<const Node* const> from,
 #if !defined(ORT_MINIMAL_BUILD)
 void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
                                  const std::function<bool(const Node*, const Node*)>& comp) const {
-  std::unordered_map<NodeIndex, size_t> in_degree;
+  std::vector<size_t> in_degree(MaxNodeIndex(), 0);
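A minimal Kahn's topological sort using a flat vector for in-degrees, as in the diff above, illustrates why the change helps: indexing by node id is O(1) with no hashing, unlike std::unordered_map<NodeIndex, size_t>. This is a simplified sketch over (from, to) edge pairs, not ORT's actual Graph::KahnsTopologicalSort:

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

std::vector<std::size_t> KahnsSort(
    std::size_t num_nodes,
    const std::vector<std::pair<std::size_t, std::size_t>>& edges) {
  // Flat in-degree table indexed by node id; no hash lookups.
  std::vector<std::size_t> in_degree(num_nodes, 0);
  std::vector<std::vector<std::size_t>> adj(num_nodes);
  for (const auto& e : edges) {
    adj[e.first].push_back(e.second);
    ++in_degree[e.second];
  }
  // Seed the queue with all roots (in-degree zero).
  std::queue<std::size_t> ready;
  for (std::size_t i = 0; i < num_nodes; ++i)
    if (in_degree[i] == 0) ready.push(i);
  // Emit nodes as their predecessors complete.
  std::vector<std::size_t> order;
  while (!ready.empty()) {
    std::size_t n = ready.front();
    ready.pop();
    order.push_back(n);
    for (std::size_t succ : adj[n])
      if (--in_degree[succ] == 0) ready.push(succ);
  }
  return order;
}
```

The trade-off is memory proportional to MaxNodeIndex() rather than to the number of live nodes, which is usually a win for dense index spaces.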
I have updated the PR to replace the "__backward" Node attribute with a bool property (force-pushed from a1a68ed to 6673a0c).
There might be a hole in the replacement of the attribute by a property: should it be possible to save/restore an ONNX file with the newly added forward attribute? If so, this change is invalid. I had a separate version of this change where I was able to save 50% of the unordered_map lookups, which saved a little time and was less intrusive. Compared to the most optimal version this adds ~1s to the 6.5s load time, yet is still ~1s faster than having the two lookups. @askhade If ONNX Runtime is compiled for inference only, could we just put an ifdef around the expensive map lookup so that inference startups are quick?
I asked Ashwini at the time and she said there was no save/restore involving the attribute so I believe it's fine.
Looks like that is being done in #19475. Makes sense to me to exclude all logic looking for backwards nodes given they won't exist in a build that does not include the training code.
Co-authored-by: Scott McKay <[email protected]>
@@ -394,6 +398,12 @@ class Node {
   /** Gets the Node's attributes. */
   const NodeAttributes& GetAttributes() const noexcept { return attributes_; }

+  /** @returns true if the Node is a forward node, false otherwise. **/
+  bool isForwardNode() const noexcept { return is_forward_node_; }
I used node.isForwardNode() for now, which duplicates 'Node' in the name. Would it be better to replace isForwardNode/is_forward_node_ with isForward/is_forward_? isBackwardPass might be an even better name.
I think IsBackwardPass would be clearer ('backward pass' is more meaningful than 'forward'), and I agree that having 'Node' on the end of a method name of the Node class is unnecessary.
Maybe 'ForBackwardPass' would be slightly clearer than 'Is' as well.
With the changes in #19475 most of the changes in this PR are now obsolete, and it even improves on the PriorityQueue, which was on my TODO list. I'll profile my networks again to see if any of those changes are still required.
Mostly superseded by #19475. Closing.
Description
Motivation and Context
Load time of the Hugging Face 7B Llama files can be quite long (>45s). 50% of this time is IO, while the other 50% is TransformGraph. To decrease loading time, and thus startup time, it'd be good to optimize the non-IO part as much as possible.