Skip to content

Commit 8f6ee45

Browse files
authored
Merge pull request #155 from omlins/nestedfunctions
Move macro expansions to begin
2 parents 3747b9e + ce34bf6 commit 8f6ee45

File tree

5 files changed

+44
-45
lines changed

5 files changed

+44
-45
lines changed

src/ParallelKernel/kernel_language.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,9 +283,7 @@ end
283283

284284
function threads(caller::Module, args...; package::Symbol=get_package(caller))
285285
if (package == PKG_THREADS) return :(Base.Threads.@threads($(args...)))
286-
elseif (package == PKG_POLYESTER)
287-
args = macroexpand.((caller,), args)
288-
return :(Polyester.@batch($(args...)))
286+
elseif (package == PKG_POLYESTER) return :(Polyester.@batch($(args...)))
289287
elseif isgpu(package) return :(begin end)
290288
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
291289
end

src/ParallelKernel/parallel.jl

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -173,27 +173,6 @@ function parallel_kernel(caller::Module, package::Symbol, numbertype::DataType,
173173
body = get_body(kernel)
174174
body = remove_return(body)
175175
use_aliases = !all(indices .== INDICES[1:length(indices)])
176-
if isgpu(package) kernel = insert_device_types(kernel) end
177-
kernel = push_to_signature!(kernel, :($RANGES_VARNAME::$RANGES_TYPE))
178-
if (package == PKG_CUDA) int_type = INT_CUDA
179-
elseif (package == PKG_AMDGPU) int_type = INT_AMDGPU
180-
elseif (package == PKG_THREADS) int_type = INT_THREADS
181-
elseif (package == PKG_POLYESTER) int_type = INT_POLYESTER
182-
end
183-
kernel = push_to_signature!(kernel, :($(RANGELENGTHS_VARNAMES[1])::$int_type))
184-
kernel = push_to_signature!(kernel, :($(RANGELENGTHS_VARNAMES[2])::$int_type))
185-
kernel = push_to_signature!(kernel, :($(RANGELENGTHS_VARNAMES[3])::$int_type))
186-
ranges = [:($RANGES_VARNAME[1]), :($RANGES_VARNAME[2]), :($RANGES_VARNAME[3])]
187-
if isgpu(package)
188-
body = add_threadids(indices, ranges, body)
189-
body = (numbertype != NUMBERTYPE_NONE) ? literaltypes(numbertype, body) : body
190-
body = literaltypes(int_type, body) # TODO: the size function always returns a 64 bit integer; the following is not performance efficient: body = cast(body, :size, int_type)
191-
elseif iscpu(package)
192-
body = add_loop(indices, ranges, body)
193-
body = (numbertype != NUMBERTYPE_NONE) ? literaltypes(numbertype, body) : body
194-
else
195-
@ArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
196-
end
197176
if use_aliases # NOTE: we treat explicit parallel indices as aliases to the statically retrievable indices INDICES.
198177
indices_aliases = indices
199178
indices = [INDICES[1:length(indices)]...]
@@ -202,6 +181,9 @@ function parallel_kernel(caller::Module, package::Symbol, numbertype::DataType,
202181
body = substitute(body, indices_aliases[i], indices[i])
203182
end
204183
end
184+
if isgpu(package) kernel = insert_device_types(kernel) end
185+
kernel = adjust_signatures(kernel, package)
186+
body = handle_indices_and_literals(body, indices, package, numbertype)
205187
if (inbounds) body = add_inbounds(body) end
206188
body = add_return(body)
207189
set_body!(kernel, body)
@@ -370,6 +352,34 @@ function literaltypes(type1::DataType, type2::DataType, expr::Expr)
370352
end
371353

372354

355+
## FUNCTIONS TO HANDLE SIGNATURES AND INDICES
356+
357+
# Extend the signature of `kernel` with the ranges argument (`RANGES_VARNAME::RANGES_TYPE`)
# followed by the first three range-length arguments (`RANGELENGTHS_VARNAMES[1:3]`), each
# annotated with the integer type that `kernel_int_type` returns for `package`.
# Returns the kernel expression as handed back by `push_to_signature!`.
function adjust_signatures(kernel::Expr, package::Symbol)
    rangelength_type = kernel_int_type(package)
    kernel = push_to_signature!(kernel, :($RANGES_VARNAME::$RANGES_TYPE))
    for dim in 1:3
        rangelength_arg = :($(RANGELENGTHS_VARNAMES[dim])::$rangelength_type)
        kernel = push_to_signature!(kernel, rangelength_arg)
    end
    return kernel
end
365+
366+
# Wrap the kernel `body` in the package-specific index handling and type its literals:
# - GPU packages: add thread ids (`add_threadids`), then apply `literaltypes` with
#   `numbertype` (unless it is `NUMBERTYPE_NONE`) and with the package's integer type.
# - CPU packages: add loops (`add_loop`), then apply `literaltypes` with `numbertype`
#   (unless it is `NUMBERTYPE_NONE`).
# Any other package is reported via `@ArgumentError`.
# Returns the transformed body.
function handle_indices_and_literals(body::Expr, indices::Array, package::Symbol, numbertype::DataType)
    package_int_type = kernel_int_type(package)
    # One indexing expression per dimension: RANGES_VARNAME[1], RANGES_VARNAME[2], RANGES_VARNAME[3].
    dim_ranges = [:($RANGES_VARNAME[$dim]) for dim in 1:3]
    if isgpu(package)
        body = add_threadids(indices, dim_ranges, body)
        if numbertype != NUMBERTYPE_NONE
            body = literaltypes(numbertype, body)
        end
        body = literaltypes(package_int_type, body) # TODO: the size function always returns a 64 bit integer; the following is not performance efficient: body = cast(body, :size, int_type)
    elseif iscpu(package)
        body = add_loop(indices, dim_ranges, body)
        if numbertype != NUMBERTYPE_NONE
            body = literaltypes(numbertype, body)
        end
    else
        @ArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
    end
    return body
end
381+
382+
373383
## FUNCTIONS TO ADD THREAD-IDS / LOOPS IN KERNELS
374384

375385
function add_threadids(indices::Array, ranges::Array, block::Expr)

src/ParallelKernel/shared.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ end
7070
# Expand to the escaped value of `RANGES_VARNAME`, so that it is resolved in the caller's scope.
macro ranges()
    return esc(RANGES_VARNAME)
end
7171
# Expand to an escaped tuple expression holding all entries of `RANGELENGTHS_VARNAMES`,
# so that they are resolved in the caller's scope.
macro rangelengths()
    rangelengths_tuple = :(($(RANGELENGTHS_VARNAMES...),))
    return esc(rangelengths_tuple)
end
7272

73+
# Return the integer type associated with the parallelization `package`:
# `INT_CUDA`, `INT_AMDGPU`, `INT_THREADS` or `INT_POLYESTER` for `PKG_CUDA`, `PKG_AMDGPU`,
# `PKG_THREADS` or `PKG_POLYESTER`, respectively.
#
# Throws an `ArgumentError` for any other `package`. Without the explicit `else` branch, no
# value would be assigned and the function would fail with an opaque `UndefVarError`.
function kernel_int_type(package)
    if     (package == PKG_CUDA)      return INT_CUDA
    elseif (package == PKG_AMDGPU)    return INT_AMDGPU
    elseif (package == PKG_THREADS)   return INT_THREADS
    elseif (package == PKG_POLYESTER) return INT_POLYESTER
    else
        # Plain `throw` (not a macro) so this does not depend on macro definition order.
        throw(ArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package)."))
    end
end
81+
7382

7483
## FUNCTIONS TO CHECK EXTENSIONS SUPPORT
7584

src/parallel.jl

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -281,26 +281,8 @@ function parallel_kernel(metadata_module::Module, metadata_function::Expr, calle
281281
end
282282
if isgpu(package) kernel = insert_device_types(kernel) end
283283
if !memopt
284-
kernel = push_to_signature!(kernel, :($RANGES_VARNAME::$RANGES_TYPE))
285-
if (package == PKG_CUDA) int_type = INT_CUDA
286-
elseif (package == PKG_AMDGPU) int_type = INT_AMDGPU
287-
elseif (package == PKG_THREADS) int_type = INT_THREADS
288-
elseif (package == PKG_POLYESTER) int_type = INT_POLYESTER
289-
end
290-
kernel = push_to_signature!(kernel, :($(RANGELENGTHS_VARNAMES[1])::$int_type))
291-
kernel = push_to_signature!(kernel, :($(RANGELENGTHS_VARNAMES[2])::$int_type))
292-
kernel = push_to_signature!(kernel, :($(RANGELENGTHS_VARNAMES[3])::$int_type))
293-
ranges = [:($RANGES_VARNAME[1]), :($RANGES_VARNAME[2]), :($RANGES_VARNAME[3])]
294-
if isgpu(package)
295-
body = add_threadids(indices, ranges, body)
296-
body = (numbertype != NUMBERTYPE_NONE) ? literaltypes(numbertype, body) : body
297-
body = literaltypes(int_type, body) # TODO: the size function always returns a 64 bit integer; the following is not performance efficient: body = cast(body, :size, int_type)
298-
elseif iscpu(package)
299-
body = add_loop(indices, ranges, body)
300-
body = (numbertype != NUMBERTYPE_NONE) ? literaltypes(numbertype, body) : body
301-
else
302-
@ArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
303-
end
284+
kernel = adjust_signatures(kernel, package)
285+
body = handle_indices_and_literals(body, indices, package, numbertype)
304286
if (inbounds) body = add_inbounds(body) end
305287
end
306288
body = add_return(body)

src/shared.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import MacroTools: @capture, postwalk, splitdef, splitarg # NOTE: inexpr_walk used instead of MacroTools.inexpr
2-
import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, insert_device_types, is_kernel, is_call, gensym_world, isgpu, iscpu, @isgpu, @iscpu, substitute, substitute_in_kernel, in_signature, inexpr_walk, add_inbounds, cast, @ranges, @rangelengths, @return_value, @return_nothing
2+
import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, insert_device_types, is_kernel, is_call, gensym_world, isgpu, iscpu, @isgpu, @iscpu, substitute, substitute_in_kernel, in_signature, inexpr_walk, adjust_signatures, handle_indices_and_literals, add_inbounds, cast, @ranges, @rangelengths, @return_value, @return_nothing
33
import .ParallelKernel: PKG_CUDA, PKG_AMDGPU, PKG_THREADS, PKG_POLYESTER, PKG_NONE, NUMBERTYPE_NONE, SUPPORTED_NUMBERTYPES, SUPPORTED_PACKAGES, ERRMSG_UNSUPPORTED_PACKAGE, INT_CUDA, INT_AMDGPU, INT_POLYESTER, INT_THREADS, INDICES, PKNumber, RANGES_VARNAME, RANGES_TYPE, RANGELENGTH_XYZ_TYPE, RANGELENGTHS_VARNAMES, THREADIDS_VARNAMES, GENSYM_SEPARATOR, AD_SUPPORTED_ANNOTATIONS
44
import .ParallelKernel: @require, @symbols, symbols, longnameof, @prettyexpand, @prettystring, prettystring, @gorgeousexpand, @gorgeousstring, gorgeousstring
55

0 commit comments

Comments
 (0)