
Added support for setGraphExecutorOptimize with torchscript models. #904

Merged
4 commits merged into deepjavalibrary:master on Apr 27, 2021

Conversation

@hodovo (Contributor) commented Apr 27, 2021

Description

This PR adds support for torch::jit::setGraphExecutorOptimize, which lets the user avoid the model "warmup" period during which TorchScript optimizes the model on GPU.

Users can disable TorchScript optimization with the following call:

JniUtils.setGraphExecutorOptimize(false);

Since this feature is enabled by default in TorchScript, optimization is only disabled when the method above is called. Because JNI maintains an individual environment per thread, the method must be called on every thread that uses a model for which optimization should be disabled.

This change is backwards compatible and does not alter the usage of any existing code.
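A minimal sketch of the intended usage (the surrounding class name and the model loading/inference code are placeholders, not part of this PR):

```java
import ai.djl.pytorch.jni.JniUtils;

public class DisableOptimizeExample {

    public static void main(String[] args) {
        // Disable TorchScript graph executor optimization for the current thread
        // before the model is used, so the first inferences skip the warmup period.
        JniUtils.setGraphExecutorOptimize(false);

        // ... load the TorchScript model and run inference on this same thread ...
    }
}
```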

@hodovo (Contributor, Author) commented Apr 27, 2021

@frankfliu I created this as a draft pull request since it is my first pull request. Please have a look, and you can mark it ready for review.

@hodovo marked this pull request as ready for review on April 27, 2021 07:51
@hodovo (Contributor, Author) commented Apr 27, 2021

After doing some testing, I found a bug with multi-threaded inference: the same multi-second delay reappears after a few inferences. It does not happen when running single-threaded.

I fixed it by running PyTorchLibrary.LIB.setGraphExecutorOptimize(false); before every call to PyTorchLibrary.LIB.moduleForward. This indicates that, when using multiple threads, the setGraphExecutorOptimize value is either being reset in the JNI code or not being respected.

I propose adding another parameter to PyTorchLibrary.LIB.moduleForward that takes the desired setGraphExecutorOptimize value and sets it on each call.

I am going to continue investigating why the value is not respected when set in the PtSymbolBlock on the first inference, and only when using more than one thread.

@hodovo (Contributor, Author) commented Apr 27, 2021

I was able to resolve this by calling JniUtils.setGraphExecutorOptimize(false); in each thread I planned to use. I do not have much experience with JNI, but it seems that the reset happens because each thread provides its own environment to each JNI call.

I think the best course of action would be to add a parameter to PyTorchLibrary.LIB.moduleForward as I mentioned above. Let me know what you think.
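For reference, a rough sketch of the per-thread workaround described above, assuming a fixed thread pool; the predictor/inference code is omitted and the class name is a placeholder:

```java
import ai.djl.pytorch.jni.JniUtils;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class PerThreadOptimizeExample {

    public static void main(String[] args) {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        for (int i = 0; i < 4; i++) {
            pool.submit(() -> {
                // Each thread has its own JNI environment, so the flag must be
                // set on every worker thread that runs inference, not just on main.
                JniUtils.setGraphExecutorOptimize(false);

                // ... run inference for this worker here ...
            });
        }
        pool.shutdown();
    }
}
```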

… since this is not respected in a multi-threaded environment.
@hodovo (Contributor, Author) commented Apr 27, 2021

It might be a good idea to simply let developers use the JNI function however they want rather than wiring it in one way or another. That flexibility will be needed in some situations, for example loading one set of models with optimization on one thread and another set of models without optimization on a different thread. A global approach is much less flexible, whereas exposing the call directly does not force any changes to existing code.

I just reverted my changes to PtSymbolBlock as I think flexibility is the way to go.
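A hypothetical sketch of that mixed scenario, assuming the per-thread behavior discussed above (model loading and inference are omitted; the class and thread names are placeholders):

```java
import ai.djl.pytorch.jni.JniUtils;

public class MixedOptimizeExample {

    public static void main(String[] args) throws InterruptedException {
        // Thread A keeps the default TorchScript optimization for one set of models.
        Thread optimized = new Thread(() -> {
            JniUtils.setGraphExecutorOptimize(true);
            // ... load and run the first set of models on this thread ...
        });

        // Thread B disables optimization to avoid the warmup delay for another set.
        Thread unoptimized = new Thread(() -> {
            JniUtils.setGraphExecutorOptimize(false);
            // ... load and run the second set of models on this thread ...
        });

        optimized.start();
        unoptimized.start();
        optimized.join();
        unoptimized.join();
    }
}
```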

Change-Id: I59b7ef2e2b24543d34a9c15e73add232ef55afc6
@codecov-commenter commented Apr 27, 2021

Codecov Report

Merging #904 (6ddac4b) into master (e819389) will decrease coverage by 0.01%.
The diff coverage is 0.00%.

Impacted file tree graph

@@             Coverage Diff              @@
##             master     #904      +/-   ##
============================================
- Coverage     70.34%   70.32%   -0.02%     
  Complexity     5085     5085              
============================================
  Files           501      501              
  Lines         22432    22437       +5     
  Branches       2332     2335       +3     
============================================
  Hits          15779    15779              
- Misses         5412     5417       +5     
  Partials       1241     1241              
Impacted Files                                           Coverage Δ                Complexity Δ
...ine/src/main/java/ai/djl/pytorch/jni/JniUtils.java    90.57% <0.00%> (-0.30%)   194.00 <0.00> (ø)
...c/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java    100.00% <ø> (ø)           1.00 <0.00> (ø)
api/src/main/java/ai/djl/util/Platform.java              45.65% <0.00%> (-3.19%)   8.00% <0.00%> (ø%)

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update e819389...6ddac4b.

@lanking520 merged commit 91209ae into deepjavalibrary:master on Apr 27, 2021
@hodovo deleted the torchscript-optimize branch on May 1, 2021 20:13
Lokiiiiii pushed a commit to Lokiiiiii/djl that referenced this pull request Oct 10, 2023