Skip to content

Commit 6a549bf

Browse files
authored
Merge pull request #183 from omlins/data
Fix Data module type replacement in interactive session
2 parents 1509d25 + 16b6e83 commit 6a549bf

File tree

6 files changed

+31
-31
lines changed

6 files changed

+31
-31
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
with:
2828
version: ${{ matrix.version }}
2929
arch: ${{ matrix.arch }}
30-
- uses: actions/cache@v1
30+
- uses: actions/cache@v4
3131
env:
3232
cache-name: cache-artifacts
3333
with:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ ParallelStencil_MetalExt = "Metal"
2525
[compat]
2626
AMDGPU = "0.6, 0.7, 0.8, 0.9, 1"
2727
CUDA = "3.12, 4, 5"
28-
CellArrays = "0.3"
28+
CellArrays = "0.3.2"
2929
Enzyme = "0.12, 0.13"
3030
MacroTools = "0.5"
3131
Metal = "1.2"

src/ParallelKernel/CUDAExt/allocators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ ParallelStencil.ParallelKernel.fill_cuda(::Type{T}, blocklength, args...) where
99

1010
ParallelStencil.ParallelKernel.zeros_cuda(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = (check_datatype_cuda(T); fill_cuda(T, blocklength, 0, args...))
1111
ParallelStencil.ParallelKernel.ones_cuda(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = (check_datatype_cuda(T); fill_cuda(T, blocklength, 1, args...))
12-
ParallelStencil.ParallelKernel.rand_cuda(::Type{T}, ::Val{B}, dims) where {T<:Union{SArray,FieldArray}, B} = (check_datatype_cuda(T, Bool, Enum); blocklen = (B == 0) ? prod(dims) : B; CellArray{T,length(dims),B, CUDA.CuArray{eltype(T),3}}(CUDA.rand(eltype(T), blocklen, prod(size(T)), ceil(Int,prod(dims)/(blocklen))), dims))
12+
ParallelStencil.ParallelKernel.rand_cuda(::Type{T}, ::Val{B}, dims) where {T<:Union{SArray,FieldArray}, B} = (check_datatype_cuda(T, Bool, Enum); blocklen = (B == 0) ? prod(dims) : B; CuCellArray{T,length(dims),B, eltype(T)}(CUDA.rand(eltype(T), blocklen, prod(size(T)), ceil(Int,prod(dims)/(blocklen))), dims))
1313
ParallelStencil.ParallelKernel.rand_cuda(::Type{T}, blocklength, dims...) where {T<:Union{SArray,FieldArray}} = rand_cuda(T, blocklength, dims)
1414
ParallelStencil.ParallelKernel.falses_cuda(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = fill_cuda(T, blocklength, false, args...)
1515
ParallelStencil.ParallelKernel.trues_cuda(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = fill_cuda(T, blocklength, true, args...)

src/ParallelKernel/parallel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,8 @@ function simplify_conditions(caller::Module, expr::Expr)
388388
elseif (@capture(cond, a_ <= ixyz_ - c_ <= b_) && ixyz in INDICES) cond = :($a + $c <= $ixyz <= $b + $c)
389389
end
390390
if @capture(cond, a_ < x_ < b_) || @capture(cond, a_ < x_ <= b_) || @capture(cond, a_ <= x_ < b_) || @capture(cond, a_ <= x_ <= b_)
391-
a_val = eval_try(caller, a)
392-
b_val = eval_try(caller, b)
391+
a_val = eval_try(caller, a; when_interactive=false)
392+
b_val = eval_try(caller, b; when_interactive=false)
393393
if !isnothing(a_val) cond = substitute(cond, a, :($a_val), inQuoteNode=true) end
394394
if !isnothing(b_val) cond = substitute(cond, b, :($b_val), inQuoteNode=true) end
395395
end

src/ParallelKernel/shared.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,8 @@ function eval_arg(caller::Module, arg)
377377
end
378378
end
379379

380-
function eval_try(caller::Module, expr)
381-
if isinteractive() # NOTE: this is required to avoid that this function returns non-constant values in interactive sessions.
380+
function eval_try(caller::Module, expr; when_interactive::Bool=true)
381+
if !when_interactive && isinteractive() # NOTE: this is required to avoid that this function returns non-constant values in interactive sessions, when not appropriate (e.g. in for optimization)
382382
return nothing
383383
else
384384
try
@@ -562,9 +562,9 @@ end
562562

563563

564564
function interpolate(sym::Symbol, vals::NTuple, block::Expr)
565-
return quote
565+
return flatten(unblock(quote
566566
$((substitute(block, sym, val; inQuoteNode=true, inString=true) for val in vals)...)
567-
end
567+
end))
568568
end
569569

570570
interpolate(sym::Symbol, vals_expr::Expr, block::Expr) = interpolate(sym, (extract_tuple(vals_expr)...,), block)

test/ParallelKernel/test_parallel.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ eval(:(
138138
@test Array(Ā) ≈ Ā_ref
139139
@test Array(B̄) ≈ B̄_ref
140140
end
141-
end
141+
end;
142142
@testset "@parallel_indices" begin
143143
@testset "inbounds" begin
144144
expansion = @prettystring(1, @parallel_indices (ix) inbounds=true f(A) = (2*A; return))
@@ -147,27 +147,27 @@ eval(:(
147147
@test !occursin("Base.@inbounds begin", expansion)
148148
expansion = @prettystring(1, @parallel_indices (ix) f(A) = (2*A; return))
149149
@test !occursin("Base.@inbounds begin", expansion)
150-
end
150+
end;
151151
@testset "addition of range arguments" begin
152152
expansion = @gorgeousstring(1, @parallel_indices (ix,iy) f(a::T, b::T) where T <: Union{Array{Float32}, Array{Float64}} = (println("a=$a, b=$b)"); return))
153153
@test occursin("f(a::T, b::T, ranges::Tuple{UnitRange, UnitRange, UnitRange}, rangelength_x::Int64, rangelength_y::Int64, rangelength_z::Int64", expansion)
154-
end
155-
$(interpolate(:__T__, ARRAYTYPES, :(
154+
end;
155+
@testset "Data.T to Data.Device.T" $(interpolate(:__T__, ARRAYTYPES, :(
156156
@testset "Data.__T__ to Data.Device.__T__" begin
157157
@static if @isgpu($package)
158158
expansion = @prettystring(1, @parallel_indices (ix,iy) f(A::Data.__T__, B::Data.__T__, c::T) where T <: Integer = (A[ix,iy] = B[ix,iy]^c; return))
159159
@test occursin("f(A::Data.Device.__T__, B::Data.Device.__T__,", expansion)
160160
end
161-
end
162-
)))
163-
$(interpolate(:__T__, FIELDTYPES, :(
161+
end;
162+
)));
163+
@testset "Data.Fields.T to Data.Fields.Device.T" $(interpolate(:__T__, FIELDTYPES, :(
164164
@testset "Data.Fields.__T__ to Data.Fields.Device.__T__" begin
165165
@static if @isgpu($package)
166166
expansion = @prettystring(1, @parallel_indices (ix,iy) f(A::Data.Fields.__T__, B::Data.Fields.__T__, c::T) where T <: Integer = (A[ix,iy] = B[ix,iy]^c; return))
167167
@test occursin("f(A::Data.Fields.Device.__T__, B::Data.Fields.Device.__T__,", expansion)
168168
end
169-
end
170-
)))
169+
end;
170+
)));
171171
# NOTE: the following GPU tests fail, because the Fields module cannot be imported.
172172
# @testset "Fields.Field to Data.Fields.Device.Field" begin
173173
# @static if @isgpu($package)
@@ -183,22 +183,22 @@ eval(:(
183183
# @test occursin("f(A::Data.Fields.Device.Field, B::Data.Fields.Device.Field,", expansion)
184184
# end
185185
# end
186-
$(interpolate(:__T__, ARRAYTYPES, :(
186+
@testset "TData.T to TData.Device.T" $(interpolate(:__T__, ARRAYTYPES, :(
187187
@testset "TData.__T__ to TData.Device.__T__" begin
188188
@static if @isgpu($package)
189189
expansion = @prettystring(1, @parallel_indices (ix,iy) f(A::TData.__T__, B::TData.__T__, c::T) where T <: Integer = (A[ix,iy] = B[ix,iy]^c; return))
190190
@test occursin("f(A::TData.Device.__T__, B::TData.Device.__T__,", expansion)
191191
end
192-
end
193-
)))
194-
$(interpolate(:__T__, FIELDTYPES, :(
192+
end;
193+
)));
194+
@testset "TData.Fields.T to TData.Fields.Device.T" $(interpolate(:__T__, FIELDTYPES, :(
195195
@testset "TData.Fields.__T__ to TData.Fields.Device.__T__" begin
196196
@static if @isgpu($package)
197197
expansion = @prettystring(1, @parallel_indices (ix,iy) f(A::TData.Fields.__T__, B::TData.Fields.__T__, c::T) where T <: Integer = (A[ix,iy] = B[ix,iy]^c; return))
198198
@test occursin("f(A::TData.Fields.Device.__T__, B::TData.Fields.Device.__T__,", expansion)
199199
end
200-
end
201-
)))
200+
end;
201+
)));
202202
# NOTE: the following GPU tests fail, because the Fields module cannot be imported.
203203
# @testset "Fields.Field to TData.Fields.Device.Field" begin
204204
# @static if @isgpu($package)
@@ -214,14 +214,14 @@ eval(:(
214214
# @test occursin("f(A::TData.Fields.Device.Field, B::TData.Fields.Device.Field,", expansion)
215215
# end
216216
# end
217-
$(interpolate(:__T__, ARRAYTYPES, :(
217+
@testset "Nested Data.T to Data.Device.T" $(interpolate(:__T__, ARRAYTYPES, :(
218218
@testset "Nested Data.__T__ to Data.Device.__T__" begin
219219
@static if @isgpu($package)
220220
expansion = @prettystring(1, @parallel_indices (ix,iy) f(A::NamedTuple{T1, NTuple{T2,T3}} where {T1,T2} where T3 <: Data.__T__, c::T) where T <: Integer = (A[ix,iy] = B[ix,iy]^c; return))
221221
@test occursin("f(A::((NamedTuple{T1, NTuple{T2, T3}} where {T1, T2}) where T3 <: Data.Device.__T__),", expansion)
222222
end
223-
end
224-
)))
223+
end;
224+
)));
225225
@testset "@parallel_indices (1D)" begin
226226
A = @zeros(4)
227227
@parallel_indices (ix) function write_indices!(A)
@@ -422,22 +422,22 @@ eval(:(
422422
@require !@is_initialized()
423423
@init_parallel_kernel(package = $package)
424424
@require @is_initialized
425-
$(interpolate(:__T__, ARRAYTYPES, :(
425+
@testset "Data.T{T2} to Data.Device.T{T2}" $(interpolate(:__T__, ARRAYTYPES, :(
426426
@testset "Data.__T__{T2} to Data.Device.__T__{T2}" begin
427427
@static if @isgpu($package)
428428
expansion = @prettystring(1, @parallel_indices (ix,iy) f(A::Data.__T__{T2}, B::Data.__T__{T2}, c<:Integer) where T2 <: Union{Float32, Float64} = (A[ix,iy] = B[ix,iy]^c; return))
429429
@test occursin("f(A::Data.Device.__T__{T2}, B::Data.Device.__T__{T2},", expansion)
430430
end
431431
end;
432-
)))
433-
$(interpolate(:__T__, FIELDTYPES, :(
432+
)));
433+
@testset "Data.Fields.T{T2} to Data.Fields.Device.T{T2}" $(interpolate(:__T__, FIELDTYPES, :(
434434
@testset "Data.Fields.__T__{T2} to Data.Fields.Device.__T__{T2}" begin
435435
@static if @isgpu($package)
436436
expansion = @prettystring(1, @parallel_indices (ix,iy) f(A::Data.Fields.__T__{T2}, B::Data.Fields.__T__{T2}, c<:Integer) where T2 <: Union{Float32, Float64} = (A[ix,iy] = B[ix,iy]^c; return))
437437
@test occursin("f(A::Data.Fields.Device.__T__{T2}, B::Data.Fields.Device.__T__{T2},", expansion)
438438
end
439439
end;
440-
)))
440+
)));
441441
@reset_parallel_kernel()
442442
end;
443443
@testset "5. Exceptions" begin

0 commit comments

Comments
 (0)