Skip to content

Commit

Permalink
Merge pull request #182 from omlins/revise
Browse files Browse the repository at this point in the history
Ensure full compatibility with Revise.jl
  • Loading branch information
omlins authored Dec 11, 2024
2 parents 66a98d9 + 493374e commit 93fee7f
Show file tree
Hide file tree
Showing 14 changed files with 200 additions and 57 deletions.
53 changes: 37 additions & 16 deletions src/ParallelKernel/init_parallel_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,28 +87,49 @@ function init_parallel_kernel(caller::Module, package::Symbol, numbertype::DataT
end


function Metadata_PK()
:(module $MOD_METADATA_PK # NOTE: there cannot be any newline before 'module $MOD_METADATA_PK' or it will create a begin end block and the module creation will fail.
let
global set_initialized, is_initialized, set_package, get_package, set_numbertype, get_numbertype, set_inbounds, get_inbounds, set_padding, get_padding
_is_initialized::Bool = false
package::Symbol = $(quote_expr(PKG_NONE))
numbertype::DataType = $NUMBERTYPE_NONE
inbounds::Bool = $INBOUNDS_DEFAULT
padding::Bool = $PADDING_DEFAULT
set_initialized(flag::Bool) = (_is_initialized = flag)
is_initialized() = _is_initialized
set_package(pkg::Symbol) = (package = pkg)
get_package() = package
set_numbertype(T::DataType) = (numbertype = T)
get_numbertype() = numbertype
set_inbounds(flag::Bool) = (inbounds = flag)
get_inbounds() = inbounds
set_padding(flag::Bool) = (padding = flag)
get_padding() = padding
end
end)
end

createmeta_PK(caller::Module) = if !hasmeta_PK(caller) @eval(caller, $(Metadata_PK())) end


macro is_initialized() is_initialized(__module__) end
macro get_package() esc(get_package(__module__)) end # NOTE: escaping is required here, to avoid that the symbol is evaluated in this module, instead of just being returned as a symbol.
macro get_numbertype() get_numbertype(__module__) end
macro get_inbounds() get_inbounds(__module__) end
macro get_padding() get_padding(__module__) end
let
global is_initialized, set_initialized, set_package, get_package, set_numbertype, get_numbertype, set_inbounds, get_inbounds, set_padding, get_padding, check_initialized, check_already_initialized
_is_initialized::Dict{Module, Bool} = Dict{Module, Bool}()
package::Dict{Module, Symbol} = Dict{Module, Symbol}()
numbertype::Dict{Module, DataType} = Dict{Module, DataType}()
inbounds::Dict{Module, Bool} = Dict{Module, Bool}()
padding::Dict{Module, Bool} = Dict{Module, Bool}()
set_initialized(caller::Module, flag::Bool) = (_is_initialized[caller] = flag)
is_initialized(caller::Module) = haskey(_is_initialized, caller) && _is_initialized[caller]
set_package(caller::Module, pkg::Symbol) = (package[caller] = pkg)
get_package(caller::Module) = package[caller]
set_numbertype(caller::Module, T::DataType) = (numbertype[caller] = T)
get_numbertype(caller::Module) = numbertype[caller]
set_inbounds(caller::Module, flag::Bool) = (inbounds[caller] = flag)
get_inbounds(caller::Module) = inbounds[caller]
set_padding(caller::Module, flag::Bool) = (padding[caller] = flag)
get_padding(caller::Module) = padding[caller]
global is_initialized, set_initialized, set_package, get_package, set_numbertype, get_numbertype, set_inbounds, get_inbounds, set_padding, get_padding, check_initialized, check_already_initialized
set_initialized(caller::Module, flag::Bool) = (createmeta_PK(caller); @eval(caller, $MOD_METADATA_PK.set_initialized($flag)))
is_initialized(caller::Module) = hasmeta_PK(caller) && @eval(caller, $MOD_METADATA_PK.is_initialized())
set_package(caller::Module, pkg::Symbol) = (createmeta_PK(caller); @eval(caller, $MOD_METADATA_PK.set_package($(quote_expr(pkg)))))
get_package(caller::Module) = hasmeta_PK(caller) ? @eval(caller, $MOD_METADATA_PK.get_package()) : PKG_NONE
set_numbertype(caller::Module, T::DataType) = (createmeta_PK(caller); @eval(caller, $MOD_METADATA_PK.set_numbertype($T)))
get_numbertype(caller::Module) = hasmeta_PK(caller) ? @eval(caller, $MOD_METADATA_PK.get_numbertype()) : NUMBERTYPE_NONE
set_inbounds(caller::Module, flag::Bool) = (createmeta_PK(caller); @eval(caller, $MOD_METADATA_PK.set_inbounds($flag)))
get_inbounds(caller::Module) = hasmeta_PK(caller) ? @eval(caller, $MOD_METADATA_PK.get_inbounds()) : INBOUNDS_DEFAULT
set_padding(caller::Module, flag::Bool) = (createmeta_PK(caller); @eval(caller, $MOD_METADATA_PK.set_padding($flag)))
get_padding(caller::Module) = hasmeta_PK(caller) ? @eval(caller, $MOD_METADATA_PK.get_padding()) : PADDING_DEFAULT
check_initialized(caller::Module) = if !is_initialized(caller) @NotInitializedError("no ParallelKernel macro or function can be called before @init_parallel_kernel in each module (missing call in $caller).") end
check_already_initialized(caller::Module) = if is_initialized(caller) @IncoherentCallError("ParallelKernel has already been initialized for the module $caller.") end
end
Expand Down
8 changes: 5 additions & 3 deletions src/ParallelKernel/reset_parallel_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ function reset_parallel_kernel(caller::Module)
tdata_module = TData_none()
@eval(caller, $tdata_module)
end
set_initialized(caller, false)
set_package(caller, PKG_NONE)
set_numbertype(caller, NUMBERTYPE_NONE)
if isdefined(caller, MOD_METADATA_PK)
set_initialized(caller, false)
set_package(caller, PKG_NONE)
set_numbertype(caller, NUMBERTYPE_NONE)
end
return nothing
end
13 changes: 10 additions & 3 deletions src/ParallelKernel/shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ gensym_world(tag::String, generator::Module) = gensym(string(tag, GENSYM_SEPARAT
gensym_world(tag::Symbol, generator::Module) = gensym(string(tag, GENSYM_SEPARATOR, generator))
gensym_world(tag::Expr, generator::Module) = gensym(string(tag, GENSYM_SEPARATOR, generator))

ixd(count) = @ModuleInternalError("function ixd had not be evaluated at parse time")
iyd(count) = @ModuleInternalError("function iyd had not be evaluated at parse time")
izd(count) = @ModuleInternalError("function izd had not be evaluated at parse time")
ixd(count) = @ModuleInternalError("function ixd had not been evaluated at parse time")
iyd(count) = @ModuleInternalError("function iyd had not been evaluated at parse time")
izd(count) = @ModuleInternalError("function izd had not been evaluated at parse time")

const MOD_METADATA_PK = gensym_world("__metadata_PK__", @__MODULE__) # # TODO: name mangling should be used here later, or if there is any sense to leave it like that then at check whether it's available must be done before creating it
const PKG_CUDA = :CUDA
const PKG_AMDGPU = :AMDGPU
const PKG_METAL = :Metal
Expand Down Expand Up @@ -53,6 +54,8 @@ const SUPPORTED_LITERALTYPES = [Float16, Float32, Float64, Complex{Fl
const SUPPORTED_NUMBERTYPES = [Float16, Float32, Float64, Complex{Float16}, Complex{Float32}, Complex{Float64}]
const PKNumber = Union{Float16, Float32, Float64, Complex{Float16}, Complex{Float32}, Complex{Float64}} # NOTE: this always needs to correspond to SUPPORTED_NUMBERTYPES!
const NUMBERTYPE_NONE = DataType
const INBOUNDS_DEFAULT = false
const PADDING_DEFAULT = false
const MODULENAME_DATA = :Data
const MODULENAME_TDATA = :TData
const MODULENAME_DEVICE = :Device
Expand Down Expand Up @@ -566,12 +569,16 @@ end

interpolate(sym::Symbol, vals_expr::Expr, block::Expr) = interpolate(sym, (extract_tuple(vals_expr)...,), block)

quote_expr(expr) = :($(Expr(:quote, expr)))


## FUNCTIONS/MACROS FOR DIVERSE SYNTAX SUGAR

iscpu(package) = return (package in (PKG_THREADS, PKG_POLYESTER))
isgpu(package) = return (package in (PKG_CUDA, PKG_AMDGPU, PKG_METAL))

hasmeta_PK(caller::Module) = isdefined(caller, MOD_METADATA_PK)


## TEMPORARY FUNCTION DEFINITIONS TO BE MERGED IN MACROTOOLS (https://github.com/FluxML/MacroTools.jl/pull/173)

Expand Down
75 changes: 51 additions & 24 deletions src/init_parallel_stencil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,41 @@ function init_parallel_stencil(caller::Module, package::Symbol, numbertype::Data
end


function Metadata_PS()
:(module $MOD_METADATA_PS # NOTE: there cannot be any newline before 'module $MOD_METADATA_PS' or it will create a begin end block and the module creation will fail.
let
global set_initialized, is_initialized, set_package, get_package, set_numbertype, get_numbertype, set_ndims, get_ndims, set_inbounds, get_inbounds, set_padding, get_padding, set_memopt, get_memopt, set_nonconst_metadata, get_nonconst_metadata
_is_initialized::Bool = false
package::Symbol = $(quote_expr(PKG_NONE))
numbertype::DataType = $NUMBERTYPE_NONE
ndims::Integer = $NDIMS_NONE
inbounds::Bool = $INBOUNDS_DEFAULT
padding::Bool = $PADDING_DEFAULT
memopt::Bool = $MEMOPT_DEFAULT
nonconst_metadata::Bool = $NONCONST_METADATA_DEFAULT
set_initialized(flag::Bool) = (_is_initialized = flag)
is_initialized() = _is_initialized
set_package(pkg::Symbol) = (package = pkg)
get_package() = package
set_numbertype(T::DataType) = (numbertype = T)
get_numbertype() = numbertype
set_ndims(n::Integer) = (ndims = n)
get_ndims() = ndims
set_inbounds(flag::Bool) = (inbounds = flag)
get_inbounds() = inbounds
set_padding(flag::Bool) = (padding = flag)
get_padding() = padding
set_memopt(flag::Bool) = (memopt = flag)
get_memopt() = memopt
set_nonconst_metadata(flag::Bool) = (nonconst_metadata = flag)
get_nonconst_metadata() = nonconst_metadata
end
end)
end

createmeta_PS(caller::Module) = if !hasmeta_PS(caller) @eval(caller, $(Metadata_PS())) end


macro is_initialized() is_initialized(__module__) end
macro get_package() esc(get_package(__module__)) end # NOTE: escaping is required here, to avoid that the symbol is evaluated in this module, instead of just being returned as a symbol.
macro get_numbertype() get_numbertype(__module__) end
Expand All @@ -78,30 +113,22 @@ macro get_memopt() get_memopt(__module__) end
macro get_nonconst_metadata() get_nonconst_metadata(__module__) end
let
global is_initialized, set_initialized, set_package, get_package, set_numbertype, get_numbertype, set_ndims, get_ndims, set_inbounds, get_inbounds, set_padding, get_padding, set_memopt, get_memopt, set_nonconst_metadata, get_nonconst_metadata, check_initialized, check_already_initialized
_is_initialized::Dict{Module, Bool} = Dict{Module, Bool}()
package::Dict{Module, Symbol} = Dict{Module, Symbol}()
numbertype::Dict{Module, DataType} = Dict{Module, DataType}()
ndims::Dict{Module, Integer} = Dict{Module, Integer}()
inbounds::Dict{Module, Bool} = Dict{Module, Bool}()
padding::Dict{Module, Bool} = Dict{Module, Bool}()
memopt::Dict{Module, Bool} = Dict{Module, Bool}()
nonconst_metadata::Dict{Module, Bool} = Dict{Module, Bool}()
set_initialized(caller::Module, flag::Bool) = (_is_initialized[caller] = flag)
is_initialized(caller::Module) = haskey(_is_initialized, caller) && _is_initialized[caller]
set_package(caller::Module, pkg::Symbol) = (package[caller] = pkg)
get_package(caller::Module) = package[caller]
set_numbertype(caller::Module, T::DataType) = (numbertype[caller] = T)
get_numbertype(caller::Module) = numbertype[caller]
set_ndims(caller::Module, n::Integer) = (ndims[caller] = n)
get_ndims(caller::Module) = ndims[caller]
set_inbounds(caller::Module, flag::Bool) = (inbounds[caller] = flag)
get_inbounds(caller::Module) = inbounds[caller]
set_padding(caller::Module, flag::Bool) = (padding[caller] = flag)
get_padding(caller::Module) = padding[caller]
set_memopt(caller::Module, flag::Bool) = (memopt[caller] = flag)
get_memopt(caller::Module) = memopt[caller]
set_nonconst_metadata(caller::Module, flag::Bool) = (nonconst_metadata[caller] = flag)
get_nonconst_metadata(caller::Module) = nonconst_metadata[caller]
set_initialized(caller::Module, flag::Bool) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_initialized($flag)))
is_initialized(caller::Module) = hasmeta_PS(caller) && @eval(caller, $MOD_METADATA_PS.is_initialized())
set_package(caller::Module, pkg::Symbol) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_package($(quote_expr(pkg)))))
get_package(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_package()) : PKG_NONE
set_numbertype(caller::Module, T::DataType) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_numbertype($T)))
get_numbertype(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_numbertype()) : NUMBERTYPE_NONE
set_ndims(caller::Module, n::Integer) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_ndims($n)))
get_ndims(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_ndims()) : NDIMS_NONE
set_inbounds(caller::Module, flag::Bool) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_inbounds($flag)))
get_inbounds(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_inbounds()) : INBOUNDS_DEFAULT
set_padding(caller::Module, flag::Bool) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_padding($flag)))
get_padding(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_padding()) : PADDING_DEFAULT
set_memopt(caller::Module, flag::Bool) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_memopt($flag)))
get_memopt(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_memopt()) : MEMOPT_DEFAULT
set_nonconst_metadata(caller::Module, flag::Bool) = (createmeta_PS(caller); @eval(caller, $MOD_METADATA_PS.set_nonconst_metadata($flag)))
get_nonconst_metadata(caller::Module) = hasmeta_PS(caller) ? @eval(caller, $MOD_METADATA_PS.get_nonconst_metadata()) : NONCONST_METADATA_DEFAULT
check_initialized(caller::Module) = if !is_initialized(caller) @NotInitializedError("no ParallelStencil macro or function can be called before @init_parallel_stencil in each module (missing call in $caller).") end

function check_already_initialized(caller::Module, package::Symbol, numbertype::DataType, ndims::Integer, inbounds::Bool, padding::Bool, memopt::Bool, nonconst_metadata::Bool)
Expand Down
4 changes: 2 additions & 2 deletions src/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,8 @@ end

function create_metadata_storage(source::LineNumberNode, caller::Module, kernel::Expr)
kernelid = get_kernelid(get_name(kernel), source.file, source.line)
create_module(caller, MOD_METADATA)
topmodule = @eval(caller, $MOD_METADATA)
create_module(caller, MOD_METADATA_PS)
topmodule = @eval(caller, $MOD_METADATA_PS)
create_module(topmodule, kernelid)
metadata_module = @eval(topmodule, $kernelid)
metadata_function = create_metadata_function(kernel, metadata_module)
Expand Down
10 changes: 6 additions & 4 deletions src/reset_parallel_stencil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ macro reset_parallel_stencil() esc(reset_parallel_stencil(__module__)) end

function reset_parallel_stencil(caller::Module)
ParallelKernel.reset_parallel_kernel(caller)
set_initialized(caller, false)
set_package(caller, PKG_NONE)
set_numbertype(caller, NUMBERTYPE_NONE)
set_ndims(caller, NDIMS_NONE)
if isdefined(caller, MOD_METADATA_PS)
set_initialized(caller, false)
set_package(caller, PKG_NONE)
set_numbertype(caller, NUMBERTYPE_NONE)
set_ndims(caller, NDIMS_NONE)
end
return nothing
end
Loading

0 comments on commit 93fee7f

Please sign in to comment.