Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

GPU Memory Query to C API #12083

Merged
merged 2 commits into from
Aug 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ struct Context {
* \return The number of GPUs that are available.
*/
inline static int32_t GetGPUCount();
/*!
* \brief get the free and total available memory on a GPU
* \param dev the GPU number to query
* \param free_mem pointer to the integer holding free GPU memory
* \param total_mem pointer to the integer holding total GPU memory
* \return No return value
*/
inline static void GetGPUMemoryInformation(int dev, int *free, int *total);
/*!
* Create a pinned CPU context.
* \param dev_id the device id for corresponding GPU.
Expand Down Expand Up @@ -326,6 +334,35 @@ inline int32_t Context::GetGPUCount() {
#endif
}

inline void Context::GetGPUMemoryInformation(int dev, int *free_mem,
int *total_mem) {
#if MXNET_USE_CUDA

size_t memF, memT;
cudaError_t e;

int curDevice;
e = cudaGetDevice(&curDevice);
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);

e = cudaSetDevice(dev);
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);

e = cudaMemGetInfo(&memF, &memT);
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);

e = cudaSetDevice(curDevice);
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);

*free_mem = static_cast<int>(memF);
*total_mem = static_cast<int>(memT);

#else
LOG(FATAL)
<< "This call is only supported for MXNet built with CUDA support.";
#endif
}

inline Context Context::FromString(const std::string& str) {
Context ret;
try {
Expand Down
9 changes: 9 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,15 @@ MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size);
*/
MXNET_DLL int MXGetGPUCount(int* out);

/*!
* \brief get the free and total available memory on a GPU
* \param dev the GPU number to query
* \param free_mem pointer to the integer holding free GPU memory
* \param total_mem pointer to the integer holding total GPU memory
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem);

/*!
* \brief get the MXNet library version as an integer
* \param pointer to the integer holding the version number
Expand Down
6 changes: 6 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ int MXGetGPUCount(int* out) {
API_END();
}

int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) {
API_BEGIN();
Context::GetGPUMemoryInformation(dev, free_mem, total_mem);
API_END();
}

int MXGetVersion(int *out) {
API_BEGIN();
*out = static_cast<int>(MXNET_VERSION);
Expand Down