From efa42c0e29b22dfad0a516c74a6255e702f844dd Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Thu, 31 Oct 2024 10:24:29 +0100 Subject: [PATCH] handel inverses --- src/ParallelKernel/parallel.jl | 13 ++++++++++++- src/parallel.jl | 1 + src/shared.jl | 2 +- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/ParallelKernel/parallel.jl b/src/ParallelKernel/parallel.jl index 28a5af08..81b1e9c2 100644 --- a/src/ParallelKernel/parallel.jl +++ b/src/ParallelKernel/parallel.jl @@ -190,6 +190,7 @@ function parallel_kernel(caller::Module, package::Symbol, numbertype::DataType, if isgpu(package) kernel = insert_device_types(caller, kernel) end kernel = adjust_signatures(kernel, package) body = handle_padding(body, padding) # TODO: padding can later be made configurable per kernel (to enable working with arrays as before). + body = handle_inverses(body) body = handle_indices_and_literals(body, indices, package, numbertype) if (inbounds) body = add_inbounds(body) end body = add_return(body) @@ -361,7 +362,7 @@ function literaltypes(type1::DataType, type2::DataType, expr::Expr) end -## FUNCTIONS TO HANDLE SIGNATURES AND INDICES +## FUNCTIONS TO HANDLE SIGNATURES, INDICES, INVERSES AND PADDING function adjust_signatures(kernel::Expr, package::Symbol) int_type = kernel_int_type(package) @@ -372,6 +373,16 @@ function adjust_signatures(kernel::Expr, package::Symbol) return kernel end +function handle_inverses(body::Expr) + return postwalk(body) do ex + if @capture(ex, (1 | 1.0 | 1.0f0) / x_) + return :(inv($x)) + else + return ex + end + end +end + function handle_padding(body::Expr, padding::Bool) body = substitute_indices_inn(body, padding) if padding diff --git a/src/parallel.jl b/src/parallel.jl index 70f294da..85f9fe92 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -288,6 +288,7 @@ function parallel_kernel(metadata_module::Module, metadata_function::Expr, calle if isgpu(package) kernel = insert_device_types(caller, kernel) end if !memopt kernel = adjust_signatures(kernel, package) + body = handle_inverses(body) body = handle_indices_and_literals(body, indices, package, numbertype) if (inbounds) body = add_inbounds(body) end end diff --git a/src/shared.jl b/src/shared.jl index 09c4cd1e..0b7d7ca8 100644 --- a/src/shared.jl +++ b/src/shared.jl @@ -1,5 +1,5 @@ import MacroTools: @capture, postwalk, splitdef, splitarg # NOTE: inexpr_walk used instead of MacroTools.inexpr -import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, insert_device_types, is_kernel, is_call, gensym_world, isgpu, iscpu, @isgpu, @iscpu, substitute, substitute_in_kernel, in_signature, inexpr_walk, adjust_signatures, handle_indices_and_literals, add_inbounds, cast, @ranges, @rangelengths, @return_value, @return_nothing, @firstindex, @lastindex, is_access, find_vars, handle_padding +import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, insert_device_types, is_kernel, is_call, gensym_world, isgpu, iscpu, @isgpu, @iscpu, substitute, substitute_in_kernel, in_signature, inexpr_walk, adjust_signatures, handle_indices_and_literals, add_inbounds, cast, @ranges, @rangelengths, @return_value, @return_nothing, @firstindex, @lastindex, is_access, find_vars, handle_padding, handle_inverses import .ParallelKernel: PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_THREADS, PKG_POLYESTER, PKG_NONE, NUMBERTYPE_NONE, SUPPORTED_NUMBERTYPES, SUPPORTED_PACKAGES, ERRMSG_UNSUPPORTED_PACKAGE, INT_CUDA, INT_AMDGPU, INT_METAL, INT_POLYESTER, INT_THREADS, INDICES, INDICES_INN, PKNumber, RANGES_VARNAME, RANGES_TYPE, RANGELENGTH_XYZ_TYPE, RANGELENGTHS_VARNAMES, THREADIDS_VARNAMES, GENSYM_SEPARATOR, AD_SUPPORTED_ANNOTATIONS import .ParallelKernel: @require, @symbols, symbols, longnameof, @prettyexpand, @prettystring, prettystring, @gorgeousexpand, @gorgeousstring, gorgeousstring, interpolate