diff --git a/julia/src/MXNet.jl b/julia/src/MXNet.jl index 89ec88b52cdc..5a146396642b 100644 --- a/julia/src/MXNet.jl +++ b/julia/src/MXNet.jl @@ -74,7 +74,8 @@ export Executor, # context.jl export Context, cpu, - gpu + gpu, + num_gpus # model.jl export AbstractModel, diff --git a/julia/src/context.jl b/julia/src/context.jl index 71aee3020daa..bce67a593414 100644 --- a/julia/src/context.jl +++ b/julia/src/context.jl @@ -57,3 +57,14 @@ Get a GPU context with a specific id. The K GPUs on a node is typically numbered * `dev_id::Integer = 0` the GPU device id. """ gpu(dev_id::Integer = 0) = Context(GPU, dev_id) + +""" + num_gpus() + +Query CUDA for the number of GPUs present. +""" +function num_gpus() + n = Ref{Cint}() + @mxcall :MXGetGPUCount (Ref{Cint},) n + n[] +end diff --git a/julia/test/unittest/context.jl b/julia/test/unittest/context.jl new file mode 100644 index 000000000000..0a8f086a194a --- /dev/null +++ b/julia/test/unittest/context.jl @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module TestContext + +using MXNet +using Test + +function test_num_gpus() + @info "Context::num_gpus" + + @test num_gpus() >= 0 +end + +@testset "Context Test" begin + test_num_gpus() +end + + +end # module TestContext diff --git a/python/mxnet/context.py b/python/mxnet/context.py index f284e00127b4..f2b01373b280 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -276,6 +276,7 @@ def num_gpus(): check_call(_LIB.MXGetGPUCount(ctypes.byref(count))) return count.value + def gpu_memory_info(device_id=0): """Query CUDA for the free and total bytes of GPU global memory. @@ -300,6 +301,7 @@ def gpu_memory_info(device_id=0): check_call(_LIB.MXGetGPUMemoryInformation64(dev_id, ctypes.byref(free), ctypes.byref(total))) return (free.value, total.value) + def current_context(): """Returns the current context.