Skip to content

Commit 2d25593

Browse files
authored
Merge pull request #185 from omlins/memoptparams
Set memopt parameters as a function of compute capability
2 parents b251536 + 79f32b7 commit 2d25593

File tree

12 files changed

+121
-43
lines changed

12 files changed

+121
-43
lines changed

src/ParallelKernel/AMDGPUExt/defaults.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ const ERRMSG_AMDGPUEXT_NOT_LOADED = "the AMDGPU extension was not loaded. Make s
55

66
function get_priority_rocstream end
77
function get_rocstream end
8+
function get_amdgpu_compute_capability end
89

910

1011
# allocators.jl

src/ParallelKernel/AMDGPUExt/shared.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,13 @@ let
2828
while (id > length(rocstreams)) push!(rocstreams, AMDGPU.HIPStream(:low)) end
2929
return rocstreams[id]
3030
end
31-
end
31+
end
32+
33+
34+
## FUNCTIONS TO QUERY DEVICE PROPERTIES
35+
36+
function ParallelStencil.ParallelKernel.get_amdgpu_compute_capability(default::VersionNumber)
37+
compute_capability = default
38+
#TODO: implement and convert to something comparable to CUDA compute capability.
39+
return compute_capability
40+
end

src/ParallelKernel/CUDAExt/defaults.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ const ERRMSG_CUDAEXT_NOT_LOADED = "the CUDA extension was not loaded. Make sure
55

66
function get_priority_custream end
77
function get_custream end
8+
function get_cuda_compute_capability end
89

910

1011
# allocators.jl
@@ -15,4 +16,4 @@ rand_cuda(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED)
1516
falses_cuda(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED)
1617
trues_cuda(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED)
1718
fill_cuda(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED)
18-
fill_cuda!(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED)
19+
fill_cuda!(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED)

src/ParallelKernel/CUDAExt/shared.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,22 @@ let
2828
while (id > length(custreams)) push!(custreams, CuStream(; flags=CUDA.STREAM_NON_BLOCKING, priority=CUDA.priority_range()[1])) end # CUDA.priority_range()[1] is min priority. # NOTE: priority_range cannot be called outside the function as only at runtime sure that CUDA is functional.
2929
return custreams[id]
3030
end
31-
end
31+
end
32+
33+
34+
## FUNCTIONS TO QUERY DEVICE PROPERTIES
35+
36+
function ParallelStencil.ParallelKernel.get_cuda_compute_capability(default::VersionNumber)
37+
compute_capability = default
38+
if haskey(ENV, "PS_CUDA_COMPUTE_CAPABILITY")
39+
compute_capability = parse(VersionNumber, ENV["PS_CUDA_COMPUTE_CAPABILITY"])
40+
else
41+
try
42+
dev = CUDA.device()
43+
compute_capability = CUDA.capability(dev)
44+
catch e
45+
@warn "Could not determine CUDA compute capability: assuming a recent architecture. Set the environment variable PS_CUDA_COMPUTE_CAPABILITY to a specific value if desired."
46+
end
47+
end
48+
return compute_capability
49+
end

src/ParallelKernel/MetalExt/defaults.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ const ERRMSG_METALEXT_NOT_LOADED = "the Metal extension was not loaded. Make sur
44

55
function get_priority_metalstream end
66
function get_metalstream end
7+
function get_metal_compute_capability end
8+
79

810
# allocators
911

@@ -14,5 +16,3 @@ falses_metal(arg...) = @NotLoadedError(ERRMSG_METALEXT_NOT_LOADED)
1416
trues_metal(arg...) = @NotLoadedError(ERRMSG_METALEXT_NOT_LOADED)
1517
fill_metal(arg...) = @NotLoadedError(ERRMSG_METALEXT_NOT_LOADED)
1618
fill_metal!(arg...) = @NotLoadedError(ERRMSG_METALEXT_NOT_LOADED)
17-
18-

src/ParallelKernel/MetalExt/shared.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@ import Metal.MTL
66

77
@define_MtlCellArray
88

9+
910
## FUNCTIONS TO CHECK EXTENSIONS SUPPORT
11+
1012
ParallelStencil.ParallelKernel.is_loaded(::Val{:ParallelStencil_MetalExt}) = true
1113

14+
1215
## FUNCTIONS TO GET CREATE AND MANAGE METAL QUEUES
16+
1317
ParallelStencil.ParallelKernel.get_priority_metalstream(arg...) = get_priority_metalstream(arg...)
1418
ParallelStencil.ParallelKernel.get_metalstream(arg...) = get_metalstream(arg...)
1519

@@ -27,4 +31,13 @@ let
2731
while (id > length(metalqueues)) push!(metalqueues, MTL.MTLCommandQueue(Metal.device())) end
2832
return metalqueues[id]
2933
end
30-
end
34+
end
35+
36+
37+
## FUNCTIONS TO QUERY DEVICE PROPERTIES
38+
39+
function ParallelStencil.ParallelKernel.get_metal_compute_capability(default::VersionNumber)
40+
compute_capability = default
41+
#TODO: implement and convert to something comparable to CUDA compute capability.
42+
return compute_capability
43+
end

src/ParallelKernel/parallel.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ function add_threadids(indices::Array, ranges::Array, block::Expr)
555555
end
556556
end
557557
quote
558-
$tx = (ParallelStencil.ParallelKernel.@blockIdx().x-1) * ParallelStencil.ParallelKernel.@blockDim().x + ParallelStencil.ParallelKernel.@threadIdx().x; # thread ID, dimension x
558+
$tx = ((ParallelStencil.ParallelKernel.@blockIdx().x-1) * ParallelStencil.ParallelKernel.@blockDim().x + 1) + ParallelStencil.ParallelKernel.@threadIdx().x - 1; # thread ID, dimension x #NOTE: the addition and subtraction is a trick to reduce register pressure due to Int64 indexing; normally it would simply be: $tx = (ParallelStencil.ParallelKernel.@blockIdx().x-1) * ParallelStencil.ParallelKernel.@blockDim().x + ParallelStencil.ParallelKernel.@threadIdx().x; # thread ID, dimension x
559559
$thread_bounds_check
560560
$ix = $range_x[$tx] # index, dimension x
561561
$block
@@ -570,8 +570,8 @@ function add_threadids(indices::Array, ranges::Array, block::Expr)
570570
end
571571
end
572572
quote
573-
$tx = (ParallelStencil.ParallelKernel.@blockIdx().x-1) * ParallelStencil.ParallelKernel.@blockDim().x + ParallelStencil.ParallelKernel.@threadIdx().x; # thread ID, dimension x
574-
$ty = (ParallelStencil.ParallelKernel.@blockIdx().y-1) * ParallelStencil.ParallelKernel.@blockDim().y + ParallelStencil.ParallelKernel.@threadIdx().y; # thread ID, dimension y
573+
$tx = ((ParallelStencil.ParallelKernel.@blockIdx().x-1) * ParallelStencil.ParallelKernel.@blockDim().x + 1) + ParallelStencil.ParallelKernel.@threadIdx().x - 1; # thread ID, dimension x #NOTE: the addition and subtraction is a trick to reduce register pressure due to Int64 indexing; normally it would simply be: $tx = (ParallelStencil.ParallelKernel.@blockIdx().x-1) * ParallelStencil.ParallelKernel.@blockDim().x + ParallelStencil.ParallelKernel.@threadIdx().x; # thread ID, dimension x
574+
$ty = ((ParallelStencil.ParallelKernel.@blockIdx().y-1) * ParallelStencil.ParallelKernel.@blockDim().y + 1) + ParallelStencil.ParallelKernel.@threadIdx().y - 1; # thread ID, dimension y #NOTE: the addition and subtraction is a trick to reduce register pressure due to Int64 indexing; normally it would simply be: $ty = (ParallelStencil.ParallelKernel.@blockIdx().y-1) * ParallelStencil.ParallelKernel.@blockDim().y + ParallelStencil.ParallelKernel.@threadIdx().y; # thread ID, dimension y
575575
$thread_bounds_check
576576
$ix = $range_x[$tx] # index, dimension x
577577
$iy = $range_y[$ty] # index, dimension y
@@ -588,9 +588,9 @@ function add_threadids(indices::Array, ranges::Array, block::Expr)
588588
end
589589
end
590590
quote
591-
$tx = (ParallelStencil.ParallelKernel.@blockIdx().x-1) * ParallelStencil.ParallelKernel.@blockDim().x + ParallelStencil.ParallelKernel.@threadIdx().x; # thread ID, dimension x
592-
$ty = (ParallelStencil.ParallelKernel.@blockIdx().y-1) * ParallelStencil.ParallelKernel.@blockDim().y + ParallelStencil.ParallelKernel.@threadIdx().y; # thread ID, dimension y
593-
$tz = (ParallelStencil.ParallelKernel.@blockIdx().z-1) * ParallelStencil.ParallelKernel.@blockDim().z + ParallelStencil.ParallelKernel.@threadIdx().z; # thread ID, dimension z
591+
$tx = ((ParallelStencil.ParallelKernel.@blockIdx().x-1) * ParallelStencil.ParallelKernel.@blockDim().x + 1) + ParallelStencil.ParallelKernel.@threadIdx().x - 1; # thread ID, dimension x #NOTE: the addition and subtraction is a trick to reduce register pressure due to Int64 indexing; normally it would simply be: $tx = (ParallelStencil.ParallelKernel.@blockIdx().x-1) * ParallelStencil.ParallelKernel.@blockDim().x + ParallelStencil.ParallelKernel.@threadIdx().x; # thread ID, dimension x
592+
$ty = ((ParallelStencil.ParallelKernel.@blockIdx().y-1) * ParallelStencil.ParallelKernel.@blockDim().y + 1) + ParallelStencil.ParallelKernel.@threadIdx().y - 1; # thread ID, dimension y #NOTE: the addition and subtraction is a trick to reduce register pressure due to Int64 indexing; normally it would simply be: $ty = (ParallelStencil.ParallelKernel.@blockIdx().y-1) * ParallelStencil.ParallelKernel.@blockDim().y + ParallelStencil.ParallelKernel.@threadIdx().y; # thread ID, dimension y
593+
$tz = ((ParallelStencil.ParallelKernel.@blockIdx().z-1) * ParallelStencil.ParallelKernel.@blockDim().z + 1) + ParallelStencil.ParallelKernel.@threadIdx().z - 1; # thread ID, dimension z #NOTE: the addition and subtraction is a trick to reduce register pressure due to Int64 indexing; normally it would simply be: $tz = (ParallelStencil.ParallelKernel.@blockIdx().z-1) * ParallelStencil.ParallelKernel.@blockDim().z + ParallelStencil.ParallelKernel.@threadIdx().z; # thread ID, dimension z
594594
$thread_bounds_check
595595
$ix = $range_x[$tx] # index, dimension x
596596
$iy = $range_y[$ty] # index, dimension y

src/ParallelKernel/shared.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ const INT_AMDGPU = Int64 # NOTE: ...
2626
const INT_METAL = Int64 # NOTE: ...
2727
const INT_POLYESTER = Int64 # NOTE: ...
2828
const INT_THREADS = Int64 # NOTE: ...
29+
const COMPUTE_CAPABILITY_DEFAULT = typemax(VersionNumber) # having it infinity if it is not set allows to directly use statements like `if compute_capability < v"8"`, assuming a recent architecture if it is not set.
2930
const NTHREADS_X_MAX = 32
3031
const NTHREADS_X_MAX_AMDGPU = 64
3132
const NTHREADS_MAX = 256
@@ -572,6 +573,23 @@ interpolate(sym::Symbol, vals_expr::Expr, block::Expr) = interpolate(sym, (extra
572573
quote_expr(expr) = :($(Expr(:quote, expr)))
573574

574575

576+
## FUNCTIONS TO QUERY DEVICE PROPERTIES
577+
578+
function get_compute_capability(package::Symbol)
579+
default = COMPUTE_CAPABILITY_DEFAULT
580+
if (package == PKG_CUDA) get_cuda_compute_capability(default)
581+
elseif (package == PKG_AMDGPU) get_amdgpu_compute_capability(default)
582+
elseif (package == PKG_METAL) get_metal_compute_capability(default)
583+
elseif (package == PKG_THREADS) get_cpu_compute_capability(default)
584+
elseif (package == PKG_POLYESTER) get_cpu_compute_capability(default)
585+
else
586+
@ArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package). Supported packages are: $(join(SUPPORTED_PACKAGES, ", ")).")
587+
end
588+
end
589+
590+
get_cpu_compute_capability(default::VersionNumber) = return default
591+
592+
575593
## FUNCTIONS/MACROS FOR DIVERSE SYNTAX SUGAR
576594

577595
iscpu(package) = return (package in (PKG_THREADS, PKG_POLYESTER))

src/kernel_language.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul
7878
offsets, offsets_by_z = extract_offsets(caller, body, indices, int_type, optvars, loopdim)
7979
optvars = remove_single_point_optvars(optvars, optranges, offsets, offsets_by_z)
8080
if (length(optvars)==0) @IncoherentArgumentError("incoherent argument memopt in @parallel[_indices] <kernel>: optimization can only be applied if there is at least one array that is read-only within the kernel (and accessed with a multi-point stencil). Set memopt=false for this kernel.") end
81-
optranges = define_optranges(optranges, optvars, offsets, int_type)
81+
optranges = define_optranges(optranges, optvars, offsets, int_type, package)
8282
regqueue_heads, regqueue_tails, offset_mins, offset_maxs, nb_regs_heads, nb_regs_tails = define_regqueues(offsets, optranges, optvars, indices, int_type, loopdim)
8383

8484
if loopdim == 3
@@ -102,6 +102,7 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul
102102
ranges = RANGES_VARNAME
103103
range_z = :(($ranges[3])[$tz_g])
104104
range_z_start = :(($ranges[3])[1])
105+
range_z_end = :(($ranges[3])[end])
105106
i = gensym_world("i", @__MODULE__)
106107
loopoffset = gensym_world("loopoffset", @__MODULE__)
107108

@@ -125,7 +126,7 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul
125126

126127
#TODO: replace wrap_if where possible with in-line if - compare performance when doing it
127128
body = quote
128-
$loopoffset = (@blockIdx().z-1)*$loopsize #TODO: MOVE UP - see no perf change! interchange other lines!
129+
$loopoffset = (@blockIdx().z-1)*$loopsize + $range_z_start-1 #TODO: MOVE UP - see no perf change! interchange other lines!
129130
$((quote
130131
$tx = @threadIdx().x + $hx1
131132
$ty = @threadIdx().y + $hy1
@@ -164,9 +165,12 @@ $((:( $reg = 0.0
164165
# for $i = $loopstart:$(mainloopstart-1)
165166
$(wrap_loop(i, loopstart:mainloopstart-1,
166167
quote
167-
$tz_g = $i + $loopoffset
168-
if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end
169-
$iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start
168+
$iz = $i + $loopoffset
169+
if ($iz > $range_z_end) ParallelStencil.@return_nothing; end
170+
# NOTE: the following is now fully included in the loopoffset (0.25% performance gain measured on H100) but is still of interest if we implement step ranges:
171+
# $tz_g = $i + $loopoffset
172+
# if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end
173+
# $iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start
170174
$((wrap_if(:($i > $(loopentry-1)),
171175
:( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg
172176
)
@@ -212,9 +216,12 @@ $(( # NOTE: the if statement is not needed here as we only deal with registers
212216
# for $i = $mainloopstart:$mainloopend # ParallelStencil.@unroll
213217
$(wrap_loop(i, mainloopstart:mainloopend,
214218
quote
215-
$tz_g = $i + $loopoffset
216-
if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end
217-
$iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start
219+
$iz = $i + $loopoffset
220+
if ($iz > $range_z_end) ParallelStencil.@return_nothing; end
221+
# NOTE: the following is now fully included in the loopoffset (0.25% performance gain measured on H100) but is still of interest if we implement step ranges:
222+
# $tz_g = $i + $loopoffset
223+
# if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end
224+
# $iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start
218225
$(use_any_shmem ?
219226
:( @sync_threads()
220227
) : NOEXPR
@@ -468,7 +475,7 @@ end
468475

469476
function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, indices::Union{Symbol,Expr}, optvars::Union{Expr,Symbol}, body::Expr; package::Symbol=get_package(caller))
470477
loopdim = isa(indices,Expr) ? length(indices.args) : 1
471-
loopsize = LOOPSIZE
478+
loopsize = compute_loopsize(package)
472479
optranges = nothing
473480
use_shmemhalos = nothing
474481
optimize_halo_read = true
@@ -545,7 +552,8 @@ function remove_single_point_optvars(optvars, optranges_arg, offsets, offsets_by
545552
return tuple((A for A in optvars if !(length(keys(offsets[A]))==1 && length(keys(offsets_by_z[A]))==1) || (!isnothing(optranges_arg) && A ∈ keys(optranges_arg)))...)
546553
end
547554

548-
function define_optranges(optranges_arg, optvars, offsets, int_type)
555+
function define_optranges(optranges_arg, optvars, offsets, int_type, package)
556+
compute_capability = get_compute_capability(package)
549557
optranges = Dict()
550558
for A in optvars
551559
zspan_max = 0
@@ -560,12 +568,12 @@ function define_optranges(optranges_arg, optvars, offsets, int_type)
560568
fullrange = typemin(int_type):typemax(int_type)
561569
pointrange_x = oxy_zspan_max[1]: oxy_zspan_max[1]
562570
pointrange_y = oxy_zspan_max[2]: oxy_zspan_max[2]
563-
if (!isnothing(optranges_arg) && A ∈ keys(optranges_arg)) optranges[A] = getproperty(optranges_arg, A)
564-
elseif (length(optvars) <= FULLRANGE_THRESHOLD) optranges[A] = (fullrange, fullrange, fullrange)
565-
elseif (USE_FULLRANGE_DEFAULT == (true, true, true)) optranges[A] = (fullrange, fullrange, fullrange)
566-
elseif (USE_FULLRANGE_DEFAULT == (false, true, true)) optranges[A] = (pointrange_x, fullrange, fullrange)
567-
elseif (USE_FULLRANGE_DEFAULT == (true, false, true)) optranges[A] = (fullrange, pointrange_y, fullrange)
568-
elseif (USE_FULLRANGE_DEFAULT == (false, false, true)) optranges[A] = (pointrange_x, pointrange_y, fullrange)
571+
if (!isnothing(optranges_arg) && A ∈ keys(optranges_arg)) optranges[A] = getproperty(optranges_arg, A)
572+
elseif (compute_capability < v"8" && (length(optvars) <= FULLRANGE_THRESHOLD)) optranges[A] = (fullrange, fullrange, fullrange)
573+
elseif (USE_FULLRANGE_DEFAULT == (true, true, true)) optranges[A] = (fullrange, fullrange, fullrange)
574+
elseif (USE_FULLRANGE_DEFAULT == (false, true, true)) optranges[A] = (pointrange_x, fullrange, fullrange)
575+
elseif (USE_FULLRANGE_DEFAULT == (true, false, true)) optranges[A] = (fullrange, pointrange_y, fullrange)
576+
elseif (USE_FULLRANGE_DEFAULT == (false, false, true)) optranges[A] = (pointrange_x, pointrange_y, fullrange)
569577
end
570578
end
571579
return optranges

0 commit comments

Comments
 (0)