Skip to content

Commit

Permalink
GPU Memory Query to C API (apache#12083)
Browse files Browse the repository at this point in the history
* add support for GPU memory query

* remove lint
  • Loading branch information
sbodenstein authored and sandeep-krishnamurthy committed Aug 10, 2018
1 parent 64a92e8 commit e1f98b8
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
37 changes: 37 additions & 0 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ struct Context {
* \return The number of GPUs that are available.
*/
inline static int32_t GetGPUCount();
/*!
* \brief get the free and total available memory on a GPU
* \param dev the GPU number to query
* \param free_mem pointer to the integer holding free GPU memory
* \param total_mem pointer to the integer holding total GPU memory
* \return No return value
*/
inline static void GetGPUMemoryInformation(int dev, int *free, int *total);
/*!
* Create a pinned CPU context.
* \param dev_id the device id for corresponding GPU.
Expand Down Expand Up @@ -326,6 +334,35 @@ inline int32_t Context::GetGPUCount() {
#endif
}

inline void Context::GetGPUMemoryInformation(int dev, int *free_mem,
int *total_mem) {
#if MXNET_USE_CUDA

size_t memF, memT;
cudaError_t e;

int curDevice;
e = cudaGetDevice(&curDevice);
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);

e = cudaSetDevice(dev);
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);

e = cudaMemGetInfo(&memF, &memT);
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);

e = cudaSetDevice(curDevice);
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);

*free_mem = static_cast<int>(memF);
*total_mem = static_cast<int>(memT);

#else
LOG(FATAL)
<< "This call is only supported for MXNet built with CUDA support.";
#endif
}

inline Context Context::FromString(const std::string& str) {
Context ret;
try {
Expand Down
9 changes: 9 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,15 @@ MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size);
*/
MXNET_DLL int MXGetGPUCount(int* out);

/*!
* \brief get the free and total available memory on a GPU
* \param dev the GPU number to query
* \param free_mem pointer to the integer holding free GPU memory
* \param total_mem pointer to the integer holding total GPU memory
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem);

/*!
* \brief get the MXNet library version as an integer
* \param pointer to the integer holding the version number
Expand Down
6 changes: 6 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ int MXGetGPUCount(int* out) {
API_END();
}

int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) {
API_BEGIN();
Context::GetGPUMemoryInformation(dev, free_mem, total_mem);
API_END();
}

int MXGetVersion(int *out) {
API_BEGIN();
*out = static_cast<int>(MXNET_VERSION);
Expand Down

0 comments on commit e1f98b8

Please sign in to comment.