Skip to content

Commit

Permalink
handel inverses
Browse files Browse the repository at this point in the history
  • Loading branch information
omlins committed Oct 31, 2024
1 parent 1a66d40 commit efa42c0
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
13 changes: 12 additions & 1 deletion src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ function parallel_kernel(caller::Module, package::Symbol, numbertype::DataType,
if isgpu(package) kernel = insert_device_types(caller, kernel) end
kernel = adjust_signatures(kernel, package)
body = handle_padding(body, padding) # TODO: padding can later be made configurable per kernel (to enable working with arrays as before).
body = handle_inverses(body)
body = handle_indices_and_literals(body, indices, package, numbertype)
if (inbounds) body = add_inbounds(body) end
body = add_return(body)
Expand Down Expand Up @@ -361,7 +362,7 @@ function literaltypes(type1::DataType, type2::DataType, expr::Expr)
end


## FUNCTIONS TO HANDLE SIGNATURES AND INDICES
## FUNCTIONS TO HANDLE SIGNATURES, INDICES, INVERSES AND PADDING

function adjust_signatures(kernel::Expr, package::Symbol)
int_type = kernel_int_type(package)
Expand All @@ -372,6 +373,16 @@ function adjust_signatures(kernel::Expr, package::Symbol)
return kernel
end

function handle_inverses(body::Expr)
return postwalk(body) do ex
if @capture(ex, (1 | 1.0 | 1.0f0) / x_)
return :(inv($x))
else
return ex
end
end
end

function handle_padding(body::Expr, padding::Bool)
body = substitute_indices_inn(body, padding)
if padding
Expand Down
1 change: 1 addition & 0 deletions src/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ function parallel_kernel(metadata_module::Module, metadata_function::Expr, calle
if isgpu(package) kernel = insert_device_types(caller, kernel) end
if !memopt
kernel = adjust_signatures(kernel, package)
body = handle_inverses(body)
body = handle_indices_and_literals(body, indices, package, numbertype)
if (inbounds) body = add_inbounds(body) end
end
Expand Down
2 changes: 1 addition & 1 deletion src/shared.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import MacroTools: @capture, postwalk, splitdef, splitarg # NOTE: inexpr_walk used instead of MacroTools.inexpr
import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, insert_device_types, is_kernel, is_call, gensym_world, isgpu, iscpu, @isgpu, @iscpu, substitute, substitute_in_kernel, in_signature, inexpr_walk, adjust_signatures, handle_indices_and_literals, add_inbounds, cast, @ranges, @rangelengths, @return_value, @return_nothing, @firstindex, @lastindex, is_access, find_vars, handle_padding
import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, insert_device_types, is_kernel, is_call, gensym_world, isgpu, iscpu, @isgpu, @iscpu, substitute, substitute_in_kernel, in_signature, inexpr_walk, adjust_signatures, handle_indices_and_literals, add_inbounds, cast, @ranges, @rangelengths, @return_value, @return_nothing, @firstindex, @lastindex, is_access, find_vars, handle_padding, handle_inverses
import .ParallelKernel: PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_THREADS, PKG_POLYESTER, PKG_NONE, NUMBERTYPE_NONE, SUPPORTED_NUMBERTYPES, SUPPORTED_PACKAGES, ERRMSG_UNSUPPORTED_PACKAGE, INT_CUDA, INT_AMDGPU, INT_METAL, INT_POLYESTER, INT_THREADS, INDICES, INDICES_INN, PKNumber, RANGES_VARNAME, RANGES_TYPE, RANGELENGTH_XYZ_TYPE, RANGELENGTHS_VARNAMES, THREADIDS_VARNAMES, GENSYM_SEPARATOR, AD_SUPPORTED_ANNOTATIONS
import .ParallelKernel: @require, @symbols, symbols, longnameof, @prettyexpand, @prettystring, prettystring, @gorgeousexpand, @gorgeousstring, gorgeousstring, interpolate

Expand Down

0 comments on commit efa42c0

Please sign in to comment.