[Performance Regression] GPU memory increase for training and inference models #18280
@mxnet-label-bot update [Performance]
Moving the details over from the offline discussion. The problem here is that when a library containing GPU kernels is loaded, those kernels are stored in GPU memory. So the more kernels we have in the MXNet library, two things happen:
To fully resolve this, I do not believe it is feasible to keep relying on compile-time template instantiation to generate those simple kernels; instead, we should move towards runtime compilation. This has a series of advantages:
This is not a silver bullet, though, as it also has its downsides:
I believe that the benefits outweigh the downsides, especially with the push for DeepNumpy and the need to support different type combinations, which has already forced us to disable some of them in Windows builds because of these issues (see e.g. https://github.com/apache/incubator-mxnet/blob/f00b9ab5b4410a91a8f6581da696a92f85fbccf6/src/operator/numpy/np_elemwise_broadcast_op.cu#L32-L41). @apeforest @leezu @szha @eric-haibin-lin @sxjscience What are your thoughts about this?
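To make the scaling argument concrete, here is a back-of-the-envelope sketch of how ahead-of-time instantiation grows once mixed input types are supported, compared with compiling only the combinations a workload actually uses. The dtype list and op count are hypothetical, chosen only for illustration; they are not MXNet's actual numbers.

```python
# Hypothetical numbers for illustration only; not MXNet's real op count or dtype list.
DTYPES = ["float16", "float32", "float64", "int8", "int32", "int64", "uint8", "bool"]
NUM_BINARY_OPS = 20


def compile_time_kernels(num_ops, dtypes, mixed_types):
    """Kernels emitted when every op is instantiated for every supported type combination.

    mixed_types=False: one kernel per (op, dtype); both inputs share one type.
    mixed_types=True:  one kernel per (op, lhs_dtype, rhs_dtype).
    """
    per_op = len(dtypes) ** 2 if mixed_types else len(dtypes)
    return num_ops * per_op


def runtime_kernels(calls):
    """With runtime compilation, only the (op, lhs, rhs) combinations actually used get built."""
    return len(set(calls))


if __name__ == "__main__":
    print(compile_time_kernels(NUM_BINARY_OPS, DTYPES, mixed_types=False))  # 160
    print(compile_time_kernels(NUM_BINARY_OPS, DTYPES, mixed_types=True))   # 1280
    workload = [("add", "float32", "float32")] * 1000 + [("mul", "float16", "float32")]
    print(runtime_kernels(workload))  # 2
```

The quadratic factor is what makes mixed-type support expensive when every combination is instantiated at build time, while the runtime count depends only on what the workload touches.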
I agree with using JIT to resolve the binary size issue. @ptrendx @karan6181 According to the reproducible example, does that mean there is a 128 MB GPU memory increase just for storing the GPU code?
Yes, that is right: in total we lose 1407 MB of GPU memory just by loading the library. Not all of that is kernel code (I believe ~400 MB of it is the context), but the majority is.
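Figures like these can be measured from outside the process by polling `nvidia-smi` in its CSV query mode (`--query-gpu=memory.used --format=csv,noheader,nounits`, which prints one value in MiB per GPU). A minimal sketch; the helper names are hypothetical and not part of MXNet.

```python
import shutil
import subprocess

QUERY = ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"]


def parse_memory_used(csv_text):
    """Parse one integer (MiB) per GPU from nvidia-smi's noheader/nounits CSV output."""
    return [int(line.strip()) for line in csv_text.splitlines() if line.strip()]


def gpu_memory_used():
    """Return used memory in MiB for each visible GPU, or None if nvidia-smi is not on PATH."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(QUERY, capture_output=True, text=True, check=True)
    return parse_memory_used(result.stdout)
```

Sampling `gpu_memory_used()` once before the reproduction script starts and once while the process that imported MXNet is still alive gives the per-process overhead, assuming nothing else is allocating on that GPU in the meantime.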
@ptrendx Would your proposal be in scope for 2.0? One further advantage of performing compilation at runtime is that we may be able to keep libmxnet.so from becoming subject to the CUDA EULA, as long as nvrtc.h can be licensed under a compatible license.
@leezu That would require doing everything CUDA-related via RTC, whereas my proposal is currently limited to a portion of the kernels. I would like to start working on this RTC approach ~now, so yes, it is definitely in scope for 2.0 :-).
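A minimal sketch of the runtime-compilation idea for a portion of the kernels: generate the kernel source from a string template and compile it lazily, caching by operator and type combination, so only the combinations actually used end up resident. The template, the `KernelCache` class and the injected `compile_fn` are hypothetical. In a real build, `compile_fn` would wrap NVRTC (`nvrtcCreateProgram` / `nvrtcCompileProgram`) plus module loading through the CUDA driver API (e.g. `cuModuleLoadData`), which this sketch does not show; it is injected here so the caching logic runs without a GPU.

```python
from string import Template

# Hypothetical elementwise binary kernel; not MXNet's actual kernel source.
KERNEL_TEMPLATE = Template(r"""
extern "C" __global__ void ${name}(const ${lhs}* a, const ${rhs}* b, ${out}* c, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) c[i] = static_cast<${out}>(a[i] ${op} b[i]);
}
""")

# Mapping from framework dtype names to CUDA C++ type names (small subset for illustration).
CUDA_TYPES = {"float32": "float", "float64": "double", "int32": "int"}


class KernelCache:
    """Generate and compile a kernel only the first time an (op, dtypes) combination is requested."""

    def __init__(self, compile_fn):
        self._compile = compile_fn  # callable(name, source) -> compiled kernel handle
        self._cache = {}

    def get(self, op_name, op_symbol, lhs, rhs, out):
        key = (op_name, lhs, rhs, out)
        if key not in self._cache:
            name = f"{op_name}_{lhs}_{rhs}_{out}"
            src = KERNEL_TEMPLATE.substitute(
                name=name, op=op_symbol,
                lhs=CUDA_TYPES[lhs], rhs=CUDA_TYPES[rhs], out=CUDA_TYPES[out])
            self._cache[key] = self._compile(name, src)
        return self._cache[key]

    def __len__(self):
        return len(self._cache)


if __name__ == "__main__":
    def fake_compile(name, src):
        # A real implementation would compile `src` and return a loaded kernel handle.
        return name

    cache = KernelCache(fake_compile)
    cache.get("add", "+", "float32", "float32", "float32")
    cache.get("add", "+", "float32", "float32", "float32")  # cache hit, no recompilation
    cache.get("add", "+", "float32", "float64", "float64")
    print(len(cache))  # 2
```

The trade-off the thread mentions shows up directly here: the first call for each combination pays the compilation cost, and later calls are dictionary lookups.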
@leezu @apeforest If anybody wants to help with the build side of things for the dynamic loading of |
On the CMake side: what else do you expect would be required besides removing the declaration that libmxnet depends on libcuda and libnvrtc? It's not yet clear to me why further changes would be needed. Either way, I'm happy to help on the build side.
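Dropping the link-time dependency means the library has to be found and opened at runtime instead. In libmxnet itself that would be C++ with `dlopen`/`dlsym` (or `LoadLibrary` on Windows); the sketch below uses Python's `ctypes` purely as an analogy for that pattern. It calls the real NVRTC entry point `nvrtcVersion(int *major, int *minor)`, where a return value of 0 is `NVRTC_SUCCESS`; the helper names are hypothetical.

```python
import ctypes
import ctypes.util


def load_nvrtc():
    """Try to open libnvrtc at runtime; return the handle, or None if it is unavailable."""
    name = ctypes.util.find_library("nvrtc")
    if name is None:
        return None
    try:
        return ctypes.CDLL(name)
    except OSError:
        return None


def nvrtc_version(lib=None):
    """Return (major, minor) from nvrtcVersion, or None when NVRTC cannot be loaded or fails."""
    lib = lib if lib is not None else load_nvrtc()
    if lib is None:
        return None  # caller falls back or reports that runtime compilation is unavailable
    major, minor = ctypes.c_int(), ctypes.c_int()
    status = lib.nvrtcVersion(ctypes.byref(major), ctypes.byref(minor))
    if status != 0:  # anything other than NVRTC_SUCCESS
        return None
    return major.value, minor.value
```

The point of the pattern is that a missing libnvrtc becomes a runtime condition the library can handle, instead of a load failure of libmxnet.so itself.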
For the dynamic loading of libnvrtc, I think that would be it. I was also thinking about a way to avoid code duplication (along the lines of some answers to this Stack Overflow question: https://stackoverflow.com/questions/410980/include-a-text-file-in-a-c-program-as-a-char), which would require generating some files during the build.
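The Stack Overflow answers linked above describe embedding a file's contents in the binary as a `char` array. A minimal build-time generator in that spirit could look like the sketch below, so that kernel sources live in one file but are available as a string at runtime. The script, its arguments and the variable naming are hypothetical, and wiring it into CMake (for example through `add_custom_command`) is not shown.

```python
import sys
from pathlib import Path


def to_cpp_header(text, var_name):
    """Render `text` as a C++ header defining a null-terminated char array named `var_name`."""
    # Hex byte literals sidestep escaping and the string-literal length limits of some compilers.
    items = [f"0x{b:02x}" for b in text.encode("utf-8")] + ["0x00"]
    return (
        "#pragma once\n"
        f"static const char {var_name}[] = {{{', '.join(items)}}};\n"
    )


def main(src, dst, var_name):
    source = Path(src).read_text(encoding="utf-8")
    Path(dst).write_text(to_cpp_header(source, var_name), encoding="utf-8")


if __name__ == "__main__" and len(sys.argv) == 4:
    # usage: python embed_source.py INPUT OUTPUT VAR_NAME
    main(*sys.argv[1:4])
```

Because the array is null-terminated, the generated symbol can be handed to `nvrtcCreateProgram` directly as the program source.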
I see. To run |
Description
M: 1.47k to 1.64k
M: 2.0k to 2.16k
M: 1.84k to 2.0k
M: 1.55k to 1.71k
M: 1.48k to 1.64k
M: 1.86k to 2.02k
M: 1.40k to 1.56k
M: 1.54k to 1.70k
M: 1.31k to 1.47k
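As a quick consistency check, the per-row increase can be computed from the list above. The values are copied as given and every one is read as thousands ("k"); the list does not say what each row refers to, so the rows are left unlabeled here.

```python
# Before/after values from the list above, in thousands ("k") of the reported unit.
MEASUREMENTS = [
    (1.47, 1.64), (2.00, 2.16), (1.84, 2.00), (1.55, 1.71), (1.48, 1.64),
    (1.86, 2.02), (1.40, 1.56), (1.54, 1.70), (1.31, 1.47),
]


def deltas(pairs):
    """Increase per row in plain units; rounding removes floating-point noise."""
    return [round((after - before) * 1000) for before, after in pairs]


if __name__ == "__main__":
    d = deltas(MEASUREMENTS)
    print(d)               # [170, 160, 160, 160, 160, 160, 160, 160, 160]
    print(min(d), max(d))  # 160 170
```

Every row grows by about 0.16k regardless of its starting value, which is consistent with a fixed per-process overhead rather than one that scales with the model.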
To Reproduce
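Check GPU memory usage with the nvidia-smi command while running:
import mxnet as mx
a = mx.nd.zeros((1,), ctx=mx.gpu())  # the first GPU allocation initializes the CUDA context
a.wait_to_read()  # MXNet executes asynchronously; wait until the allocation has finished
input("Check nvidia-smi now, then press Enter to exit")  # keep the process alive for measurement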
Output: