Skip to content

Commit 4547cc4

Browse files
iblislinaaronmarkham
authored andcommitted
[MXNET-1430] julia: implement context.gpu_memory_info (apache#16324)
* julia: implement context.gpu_memory_info resolve MXNET-1430 * update export and NEWS
1 parent d948256 commit 4547cc4

File tree

4 files changed

+24
-3
lines changed

4 files changed

+24
-3
lines changed

julia/NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
* Add an abstract type `AbstractMXError` as the parent type for all MXNet-related
2121
API errors. (#16235)
2222

23+
* Porting more `context` functions from Python.
24+
* `num_gpus()` (#16236)
25+
* `gpu_memory_info()` (#16324)
26+
2327

2428
# v1.5.0
2529

julia/src/MXNet.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ export Executor,
7979
export Context,
8080
cpu,
8181
gpu,
82-
num_gpus
82+
num_gpus,
83+
gpu_memory_info
8384

8485
# model.jl
8586
export AbstractModel,

julia/src/context.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,21 @@ function num_gpus()
6868
@mxcall :MXGetGPUCount (Ref{Cint},) n
6969
n[]
7070
end
71+
72+
"""
73+
gpu_memory_info(dev_id = 0)::Tuple{UInt64,UInt64}
74+
75+
Query CUDA for the free and total bytes of GPU global memory.
76+
It returns a tuple of `(free memory, total memory)`.
77+
78+
```julia-repl
79+
julia> mx.gpu_memory_info()
80+
(0x00000003af240000, 0x00000003f9440000)
81+
```
82+
"""
83+
function gpu_memory_info(dev_id = 0)
84+
free = Ref{UInt64}()
85+
n = Ref{UInt64}()
86+
@mxcall :MXGetGPUMemoryInformation64 (Cint, Ref{UInt64}, Ref{UInt64}) dev_id free n
87+
free[], n[]
88+
end

python/mxnet/context.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,6 @@ def gpu_memory_info(device_id=0):
292292
Returns
293293
-------
294294
(free, total) : (int, int)
295-
The number of GPUs.
296-
297295
"""
298296
free = ctypes.c_uint64()
299297
total = ctypes.c_uint64()

0 commit comments

Comments
 (0)