diff --git a/include/mxnet/base.h b/include/mxnet/base.h index a652fe5b7073..75784a391b47 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -222,6 +222,14 @@ struct Context { * \return The number of GPUs that are available. */ inline static int32_t GetGPUCount(); + /*! + * \brief get the free and total available memory on a GPU + * \param dev the GPU number to query + * \param free_mem pointer to the integer holding free GPU memory + * \param total_mem pointer to the integer holding total GPU memory + * \return No return value + */ + inline static void GetGPUMemoryInformation(int dev, int *free, int *total); /*! * Create a pinned CPU context. * \param dev_id the device id for corresponding GPU. @@ -326,6 +334,35 @@ inline int32_t Context::GetGPUCount() { #endif } +inline void Context::GetGPUMemoryInformation(int dev, int *free_mem, + int *total_mem) { +#if MXNET_USE_CUDA + + size_t memF, memT; + cudaError_t e; + + int curDevice; + e = cudaGetDevice(&curDevice); + CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); + + e = cudaSetDevice(dev); + CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); + + e = cudaMemGetInfo(&memF, &memT); + CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); + + e = cudaSetDevice(curDevice); + CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); + + *free_mem = static_cast(memF); + *total_mem = static_cast(memT); + +#else + LOG(FATAL) + << "This call is only supported for MXNet built with CUDA support."; +#endif +} + inline Context Context::FromString(const std::string& str) { Context ret; try { diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 75147cfd706d..5a24cc0b944e 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -390,6 +390,15 @@ MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size); */ MXNET_DLL int MXGetGPUCount(int* out); +/*! + * \brief get the free and total available memory on a GPU + * \param dev the GPU number to query + * \param free_mem pointer to the integer holding free GPU memory + * \param total_mem pointer to the integer holding total GPU memory + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem); + /*! * \brief get the MXNet library version as an integer * \param pointer to the integer holding the version number diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index ed513c0d7785..1ef3f0fca9f3 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -122,6 +122,12 @@ int MXGetGPUCount(int* out) { API_END(); } +int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) { + API_BEGIN(); + Context::GetGPUMemoryInformation(dev, free_mem, total_mem); + API_END(); +} + int MXGetVersion(int *out) { API_BEGIN(); *out = static_cast(MXNET_VERSION);