From 773e9b5fba875d766aed843e84880d3b5b86e976 Mon Sep 17 00:00:00 2001 From: Sebastian Bodenstein Date: Wed, 8 Aug 2018 13:45:04 +0200 Subject: [PATCH 1/2] add support for GPU memory query --- include/mxnet/base.h | 37 +++++++++++++++++++++++++++++++++++++ include/mxnet/c_api.h | 9 +++++++++ src/c_api/c_api.cc | 6 ++++++ 3 files changed, 52 insertions(+) diff --git a/include/mxnet/base.h b/include/mxnet/base.h index a652fe5b7073..5dff392059c2 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -222,6 +222,14 @@ struct Context { * \return The number of GPUs that are available. */ inline static int32_t GetGPUCount(); + /*! + * \brief get the free and total available memory on a GPU + * \param dev the GPU number to query + * \param free_mem pointer to the integer holding free GPU memory + * \param total_mem pointer to the integer holding total GPU memory + * \return No return value + */ + inline static void GetGPUMemoryInformation(int dev, int *free, int *total); /*! * Create a pinned CPU context. * \param dev_id the device id for corresponding GPU. @@ -326,6 +334,35 @@ inline int32_t Context::GetGPUCount() { #endif } + +inline void Context::GetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) { + #if MXNET_USE_CUDA + + size_t memF, memT; + cudaError_t e; + + int curDevice; + e = cudaGetDevice(&curDevice); + CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); + + e = cudaSetDevice(dev); + CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); + + e = cudaMemGetInfo(&memF, &memT); + CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); + + e = cudaSetDevice(curDevice); + CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); + + *free_mem = static_cast(memF); + *total_mem = static_cast(memT); + +#else + LOG(FATAL) + << "This call is only supported for MXNet built with CUDA support."; +#endif +} + inline Context Context::FromString(const std::string& str) { Context ret; try { diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 75147cfd706d..d1cd0a1315bf 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -390,6 +390,15 @@ MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size); */ MXNET_DLL int MXGetGPUCount(int* out); +/*! + * \brief get the free and total available memory on a GPU + * \param dev the GPU number to query + * \param free_mem pointer to the integer holding free GPU memory + * \param total_mem pointer to the integer holding total GPU memory + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXGetGPUMemoryInformation(int dev, int* free_mem, int* total_mem); + /*! * \brief get the MXNet library version as an integer * \param pointer to the integer holding the version number diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index ed513c0d7785..1ef3f0fca9f3 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -122,6 +122,12 @@ int MXGetGPUCount(int* out) { API_END(); } +int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) { + API_BEGIN(); + Context::GetGPUMemoryInformation(dev, free_mem, total_mem); + API_END(); +} + int MXGetVersion(int *out) { API_BEGIN(); *out = static_cast(MXNET_VERSION); From 2b11ff74d33b235046006e475ebe91c79bb15c81 Mon Sep 17 00:00:00 2001 From: Sebastian Bodenstein Date: Wed, 8 Aug 2018 13:53:37 +0200 Subject: [PATCH 2/2] remove lint --- include/mxnet/base.h | 6 +++--- include/mxnet/c_api.h | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 5dff392059c2..75784a391b47 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -334,10 +334,10 @@ inline int32_t Context::GetGPUCount() { #endif } +inline void Context::GetGPUMemoryInformation(int dev, int *free_mem, + int *total_mem) { +#if MXNET_USE_CUDA -inline void Context::GetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) { - #if MXNET_USE_CUDA - size_t memF, memT; cudaError_t e; diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index d1cd0a1315bf..5a24cc0b944e 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -397,7 +397,7 @@ MXNET_DLL int MXGetGPUCount(int* out); * \param total_mem pointer to the integer holding total GPU memory * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXGetGPUMemoryInformation(int dev, int* free_mem, int* total_mem); +MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem); /*! * \brief get the MXNet library version as an integer