Skip to content

Commit 93fee7f

Browse files
authored
Merge pull request #182 from omlins/revise
Ensure full compatibility with Revise.jl
2 parents 66a98d9 + 493374e commit 93fee7f

14 files changed

+200
-57
lines changed

src/ParallelKernel/init_parallel_kernel.jl

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -87,28 +87,49 @@ function init_parallel_kernel(caller::Module, package::Symbol, numbertype::DataT
8787
end
8888

8989

90+
function Metadata_PK()
91+
:(module $MOD_METADATA_PK # NOTE: there cannot be any newline before 'module $MOD_METADATA_PK' or it will create a begin end block and the module creation will fail.
92+
let
93+
global set_initialized, is_initialized, set_package, get_package, set_numbertype, get_numbertype, set_inbounds, get_inbounds, set_padding, get_padding
94+
_is_initialized::Bool = false
95+
package::Symbol = $(quote_expr(PKG_NONE))
96+
numbertype::DataType = $NUMBERTYPE_NONE
97+
inbounds::Bool = $INBOUNDS_DEFAULT
98+
padding::Bool = $PADDING_DEFAULT
99+
set_initialized(flag::Bool) = (_is_initialized = flag)
100+
is_initialized() = _is_initialized
101+
set_package(pkg::Symbol) = (package = pkg)
102+
get_package() = package
103+
set_numbertype(T::DataType) = (numbertype = T)
104+
get_numbertype() = numbertype
105+
set_inbounds(flag::Bool) = (inbounds = flag)
106+
get_inbounds() = inbounds
107+
set_padding(flag::Bool) = (padding = flag)
108+
get_padding() = padding
109+
end
110+
end)
111+
end
112+
113+
createmeta_PK(caller::Module) = if !hasmeta_PK(caller) @eval(caller, $(Metadata_PK())) end
114+
115+
90116
macro is_initialized() is_initialized(__module__) end
91117
macro get_package() esc(get_package(__module__)) end # NOTE: escaping is required here, to avoid that the symbol is evaluated in this module, instead of just being returned as a symbol.
92118
macro get_numbertype() get_numbertype(__module__) end
93119
macro get_inbounds() get_inbounds(__module__) end
94120
macro get_padding() get_padding(__module__) end
95121
let
96-
global is_initialized, set_initialized, set_package, get_package, set_numbertype, get_numbertype, set_inbounds, get_inbounds, set_padding, get_padding, check_initialized, check_already_initialized
97-
_is_initialized::Dict{Module, Bool} = Dict{Module, Bool}()
98-
package::Dict{Module, Symbol} = Dict{Module, Symbol}()
99-
numbertype::Dict{Module, DataType} = Dict{Module, DataType}()
100-
inbounds::Dict{Module, Bool} = Dict{Module, Bool}()
101-
padding::Dict{Module, Bool} = Dict{Module, Bool}()
102-
set_initialized(caller::Module, flag::Bool) = (_is_initialized[caller] = flag)
103-
is_initialized(caller::Module) = haskey(_is_initialized, caller) && _is_initialized[caller]
104-
set_package(caller::Module, pkg::Symbol) = (package[caller] = pkg)
105-
get_package(caller::Module) = package[caller]
106-
set_numbertype(caller::Module, T::DataType) = (numbertype[caller] = T)
107-
get_numbertype(caller::Module) = numbertype[caller]
108-
set_inbounds(caller::Module, flag::Bool) = (inbounds[caller] = flag)
109-
get_inbounds(caller::Module) = inbounds[caller]
110-
set_padding(caller::Module, flag::Bool) = (padding[caller] = flag)
111-
get_padding(caller::Module) = padding[caller]
122+
global is_initialized, set_initialized, set_package, get_package, set_numbertype, get_numbertype, set_inbounds, get_inbounds, set_padding, get_padding, check_initialized, check_already_initialized
123+
set_initialized(caller::Module, flag::Bool) = (createmeta_PK(caller); @eval(caller, $MOD_METADATA_PK.set_initialized($flag)))
124+
is_initialized(caller::Module) = hasmeta_PK(caller) && @eval(caller, $MOD_METADATA_PK.is_initialized())
125+
set_package(caller::Module, pkg::Symbol) = (createmeta_PK(caller); @eval(caller, $MOD_METADATA_PK.set_package($(quote_expr(pkg)))))
126+
get_package(caller::Module) = hasmeta_PK(caller) ? @eval(caller, $MOD_METADATA_PK.get_package()) : PKG_NONE
127+
set_numbertype(caller::Module, T::DataType) = (createmeta_PK(caller); @eval(caller, $MOD_METADATA_PK.set_numbertype($T)))
128+
get_numbertype(caller::Module) = hasmeta_PK(caller) ? @eval(caller, $MOD_METADATA_PK.get_numbertype()) : NUMBERTYPE_NONE
129+
set_inbounds(caller::Module, flag::Bool) = (createmeta_PK(caller); @eval(caller, $MOD_METADATA_PK.set_inbounds($flag)))
130+
get_inbounds(caller::Module) = hasmeta_PK(caller) ? @eval(caller, $MOD_METADATA_PK.get_inbounds()) : INBOUNDS_DEFAULT
131+
set_padding(caller::Module, flag::Bool) = (createmeta_PK(caller); @eval(caller, $MOD_METADATA_PK.set_padding($flag)))
132+
get_padding(caller::Module) = hasmeta_PK(caller) ? @eval(caller, $MOD_METADATA_PK.get_padding()) : PADDING_DEFAULT
112133
check_initialized(caller::Module) = if !is_initialized(caller) @NotInitializedError("no ParallelKernel macro or function can be called before @init_parallel_kernel in each module (missing call in $caller).") end
113134
check_already_initialized(caller::Module) = if is_initialized(caller) @IncoherentCallError("ParallelKernel has already been initialized for the module $caller.") end
114135
end

src/ParallelKernel/reset_parallel_kernel.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ function reset_parallel_kernel(caller::Module)
1616
tdata_module = TData_none()
1717
@eval(caller, $tdata_module)
1818
end
19-
set_initialized(caller, false)
20-
set_package(caller, PKG_NONE)
21-
set_numbertype(caller, NUMBERTYPE_NONE)
19+
if isdefined(caller, MOD_METADATA_PK)
20+
set_initialized(caller, false)
21+
set_package(caller, PKG_NONE)
22+
set_numbertype(caller, NUMBERTYPE_NONE)
23+
end
2224
return nothing
2325
end

src/ParallelKernel/shared.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ gensym_world(tag::String, generator::Module) = gensym(string(tag, GENSYM_SEPARAT
99
gensym_world(tag::Symbol, generator::Module) = gensym(string(tag, GENSYM_SEPARATOR, generator))
1010
gensym_world(tag::Expr, generator::Module) = gensym(string(tag, GENSYM_SEPARATOR, generator))
1111

12-
ixd(count) = @ModuleInternalError("function ixd had not be evaluated at parse time")
13-
iyd(count) = @ModuleInternalError("function iyd had not be evaluated at parse time")
14-
izd(count) = @ModuleInternalError("function izd had not be evaluated at parse time")
12+
ixd(count) = @ModuleInternalError("function ixd had not been evaluated at parse time")
13+
iyd(count) = @ModuleInternalError("function iyd had not been evaluated at parse time")
14+
izd(count) = @ModuleInternalError("function izd had not been evaluated at parse time")
1515

16+
const MOD_METADATA_PK = gensym_world("__metadata_PK__", @__MODULE__) # # TODO: name mangling should be used here later, or if there is any sense to leave it like that then at check whether it's available must be done before creating it
1617
const PKG_CUDA = :CUDA
1718
const PKG_AMDGPU = :AMDGPU
1819
const PKG_METAL = :Metal
@@ -53,6 +54,8 @@ const SUPPORTED_LITERALTYPES = [Float16, Float32, Float64, Complex{Fl
5354
const SUPPORTED_NUMBERTYPES = [Float16, Float32, Float64, Complex{Float16}, Complex{Float32}, Complex{Float64}]
5455
const PKNumber = Union{Float16, Float32, Float64, Complex{Float16}, Complex{Float32}, Complex{Float64}} # NOTE: this always needs to correspond to SUPPORTED_NUMBERTYPES!
5556
const NUMBERTYPE_NONE = DataType
57+
const INBOUNDS_DEFAULT = false
58+
const PADDING_DEFAULT = false
5659
const MODULENAME_DATA = :Data
5760
const MODULENAME_TDATA = :TData
5861
const MODULENAME_DEVICE = :Device
@@ -566,12 +569,16 @@ end
566569

567570
interpolate(sym::Symbol, vals_expr::Expr, block::Expr) = interpolate(sym, (extract_tuple(vals_expr)...,), block)
568571

572+
quote_expr(expr) = :($(Expr(:quote, expr)))
573+
569574

570575
## FUNCTIONS/MACROS FOR DIVERSE SYNTAX SUGAR
571576

572577
iscpu(package) = return (package in (PKG_THREADS, PKG_POLYESTER))
573578
isgpu(package) = return (package in (PKG_CUDA, PKG_AMDGPU, PKG_METAL))
574579

580+
hasmeta_PK(caller::Module) = isdefined(caller, MOD_METADATA_PK)
581+
575582

576583
## TEMPORARY FUNCTION DEFINITIONS TO BE MERGED IN MACROTOOLS (https://github.com/FluxML/MacroTools.jl/pull/173)
577584

src/init_parallel_stencil.jl

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,41 @@ function init_parallel_stencil(caller::Module, package::Symbol, numbertype::Data
6868
end
6969

7070

71+
function Metadata_PS()
72+
:(module $MOD_METADATA_PS # NOTE: there cannot be any newline before 'module $MOD_METADATA_PS' or it will create a begin end block and the module creation will fail.
73+
let
74+
global set_initialized, is_initialized, set_package, get_package, set_numbertype, get_numbertype, set_ndims, get_ndims, set_inbounds, get_inbounds, set_padding, get_padding, set_memopt, get_memopt, set_nonconst_metadata, get_nonconst_metadata
75+
_is_initialized::Bool = false
76+
package::Symbol = $(quote_expr(PKG_NONE))
77+
numbertype::DataType = $NUMBERTYPE_NONE
78+
ndims::Integer = $NDIMS_NONE
79+
inbounds::Bool = $INBOUNDS_DEFAULT
80+
padding::Bool = $PADDING_DEFAULT
81+
memopt::Bool = $MEMOPT_DEFAULT
82+
nonconst_metadata::Bool = $NONCONST_METADATA_DEFAULT
83+
set_initialized(flag::Bool) = (_is_initialized = flag)
84+
is_initialized() = _is_initialized
85+
set_package(pkg::Symbol) = (package = pkg)
86+
get_package() = package
87+
set_numbertype(T::DataType) = (numbertype = T)
88+
get_numbertype() = numbertype
89+
set_ndims(n::Integer) = (ndims = n)
90+
get_ndims() = ndims
91+
set_inbounds(flag::Bool) = (inbounds = flag)
92+
get_inbounds() = inbounds
93+
set_padding(flag::Bool) = (padding = flag)
94+
get_padding() = padding
95+
set_memopt(flag::Bool) = (memopt = flag)
96+
get_memopt() = memopt
97+
set_nonconst_metadata(flag::Bool) = (nonconst_metadata = flag)
98+
get_nonconst_metadata() = nonconst_metadata
99+
end
100+
end)
101+
end
102+
103+
createmeta_PS(caller::Module) = if !hasmeta_PS(caller) @eval(caller, $(Metadata_PS())) end
104+
105+
71106
macro is_initialized() is_initialized(__module__) end
72107
macro get_package() esc(get_package(__module__)) end # NOTE: escaping is required here, to avoid that the symbol is evaluated in this module, instead of just being returned as a symbol.
73108
macro get_numbertype() get_numbertype(__module__) end
@@ -78,30 +113,22 @@ macro get_memopt() get_memopt(__module__) end
78113
macro get_nonconst_metadata() get_nonconst_metadata(__module__) end
79114
let
80115
global is_initialized, set_initialized, set_package, get_package, set_numbertype, get_numbertype, set_ndims, get_ndims, set_inbounds, get_inbounds, set_padding, get_padding, set_memopt, get_memopt, set_nonconst_metadata, get_nonconst_metadata, check_initialized, check_already_initialized
81-
_is_initialized::Dict{Module, Bool} = Dict{Module, Bool}()
82-
package::Dict{Module, Symbol} = Dict{Module, Symbol}()
83-
numbertype::Dict{Module, DataType} = Dict{Module, DataType}()
84-
ndims::Dict{Module, Integer} = Dict{Module, Integer}()
85-
inbounds::Dict{Module, Bool} = Dict{Module, Bool}()
86-
padding::Dict{Module, Bool} = Dict{Module, Bool}()
87-
memopt::Dict{Module, Bool} = Dict{Module, Bool}()
88-
nonconst_metadata::Dict{Module, Bool} = Dict{Module, Bool}()
89-
set_initialized(caller::Module, flag::Bool) = (_is_initialized[caller] = flag)
90-
is_initialized(caller::Module) = haskey(_is_initialized, caller) && _is_initialized[caller]
91-
set_package(caller::Module, pkg::Symbol) = (package[caller] = pkg)
92-
get_package(caller::Module) = package[caller]
93-
set_numbertype(caller::Module, T::DataType) = (numbertype[caller] = T)
94-
get_numbertype(caller::Module) = numbertype[caller]
95-
set_ndims(caller::Module, n::Integer) = (ndims[caller] = n)
96-
get_ndims(caller::Module) = ndims[caller]
97-
set_inbounds(caller::Module, flag::Bool) = (inbounds[caller] = flag)
98-
get_inbounds(caller::Module) = inbounds[caller]
99-
set_padding(caller::Module, flag::Bool) = (padding[caller] = flag)
100-
get_padding(caller::Module) = padding[caller]
101-
set_memopt(caller::Module, flag::Bool) = (memopt[caller] = flag)
102-
get_memopt(caller::Module) = memopt[caller]
103-
set_nonconst_metadata(caller::Module, flag::Bool) = (nonconst_metadata[caller] = flag)
104-
get_nonconst_metadata(caller::Module) = nonconst_metadata[caller]
116+
set_initialized(caller::Module, flag::Bool) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_initialized($flag)))
117+
is_initialized(caller::Module) = hasmeta_PS(caller) && @eval(caller, $MOD_METADATA_PS.is_initialized())
118+
set_package(caller::Module, pkg::Symbol) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_package($(quote_expr(pkg)))))
119+
get_package(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_package()) : PKG_NONE
120+
set_numbertype(caller::Module, T::DataType) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_numbertype($T)))
121+
get_numbertype(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_numbertype()) : NUMBERTYPE_NONE
122+
set_ndims(caller::Module, n::Integer) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_ndims($n)))
123+
get_ndims(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_ndims()) : NDIMS_NONE
124+
set_inbounds(caller::Module, flag::Bool) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_inbounds($flag)))
125+
get_inbounds(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_inbounds()) : INBOUNDS_DEFAULT
126+
set_padding(caller::Module, flag::Bool) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_padding($flag)))
127+
get_padding(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_padding()) : PADDING_DEFAULT
128+
set_memopt(caller::Module, flag::Bool) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_memopt($flag)))
129+
get_memopt(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_memopt()) : MEMOPT_DEFAULT
130+
set_nonconst_metadata(caller::Module, flag::Bool) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_nonconst_metadata($flag)))
131+
get_nonconst_metadata(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_nonconst_metadata()) : NONCONST_METADATA_DEFAULT
105132
check_initialized(caller::Module) = if !is_initialized(caller) @NotInitializedError("no ParallelStencil macro or function can be called before @init_parallel_stencil in each module (missing call in $caller).") end
106133

107134
function check_already_initialized(caller::Module, package::Symbol, numbertype::DataType, ndims::Integer, inbounds::Bool, padding::Bool, memopt::Bool, nonconst_metadata::Bool)

src/parallel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,8 @@ end
462462

463463
function create_metadata_storage(source::LineNumberNode, caller::Module, kernel::Expr)
464464
kernelid = get_kernelid(get_name(kernel), source.file, source.line)
465-
create_module(caller, MOD_METADATA)
466-
topmodule = @eval(caller, $MOD_METADATA)
465+
create_module(caller, MOD_METADATA_PS)
466+
topmodule = @eval(caller, $MOD_METADATA_PS)
467467
create_module(topmodule, kernelid)
468468
metadata_module = @eval(topmodule, $kernelid)
469469
metadata_function = create_metadata_function(kernel, metadata_module)

src/reset_parallel_stencil.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@ macro reset_parallel_stencil() esc(reset_parallel_stencil(__module__)) end
99

1010
function reset_parallel_stencil(caller::Module)
1111
ParallelKernel.reset_parallel_kernel(caller)
12-
set_initialized(caller, false)
13-
set_package(caller, PKG_NONE)
14-
set_numbertype(caller, NUMBERTYPE_NONE)
15-
set_ndims(caller, NDIMS_NONE)
12+
if isdefined(caller, MOD_METADATA_PS)
13+
set_initialized(caller, false)
14+
set_package(caller, PKG_NONE)
15+
set_numbertype(caller, NUMBERTYPE_NONE)
16+
set_ndims(caller, NDIMS_NONE)
17+
end
1618
return nothing
1719
end

0 commit comments

Comments
 (0)