diff --git a/Project.toml b/Project.toml
index c5c62a8f..36d6396f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,12 +13,14 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
[extensions]
ParallelStencil_AMDGPUExt = "AMDGPU"
ParallelStencil_CUDAExt = "CUDA"
ParallelStencil_EnzymeExt = "Enzyme"
+ParallelStencil_MetalExt = "Metal"
[compat]
AMDGPU = "0.6, 0.7, 0.8, 0.9, 1"
@@ -26,6 +28,7 @@ CUDA = "3.12, 4, 5"
CellArrays = "0.3"
Enzyme = "0.11, 0.12, 0.13"
MacroTools = "0.5"
+Metal = "1.2"
Polyester = "0.7"
StaticArrays = "1"
julia = "1.10" # Minimum version supporting Data module creation
@@ -35,4 +38,4 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
-test = ["Test", "TOML", "AMDGPU", "CUDA", "Enzyme", "Polyester"]
+test = ["Test", "TOML", "AMDGPU", "CUDA", "Metal", "Enzyme", "Polyester"]
diff --git a/README.md b/README.md
index 4d3f20cc..6e17abd5 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ ParallelStencil empowers domain scientists to write architecture-agnostic high-l
![Performance ParallelStencil Teff](docs/images/perf_ps2.png)
-ParallelStencil relies on the native kernel programming capabilities of [CUDA.jl] and [AMDGPU.jl] and on [Base.Threads] for high-performance computations on GPUs and CPUs, respectively. It is seamlessly interoperable with [ImplicitGlobalGrid.jl], which renders the distributed parallelization of stencil-based GPU and CPU applications on a regular staggered grid almost trivial and enables close to ideal weak scaling of real-world applications on thousands of GPUs \[[1][JuliaCon20a], [2][JuliaCon20b], [3][JuliaCon19], [4][PASC19]\]. Moreover, ParallelStencil enables hiding communication behind computation with a simple macro call and without any particular restrictions on the package used for communication. ParallelStencil has been designed in conjunction with [ImplicitGlobalGrid.jl] for simplest possible usage by domain-scientists, rendering fast and interactive development of massively scalable high performance multi-GPU applications readily accessible to them. Furthermore, we have developed a self-contained approach for "Solving Nonlinear Multi-Physics on GPU Supercomputers with Julia" relying on ParallelStencil and [ImplicitGlobalGrid.jl] \[[1][JuliaCon20a]\]. ParallelStencil's feature to hide communication behind computation was showcased when a close to ideal weak scaling was demonstrated for a 3-D poro-hydro-mechanical real-world application on up to 1024 GPUs on the Piz Daint Supercomputer \[[1][JuliaCon20a]\]:
+ParallelStencil relies on the native kernel programming capabilities of [CUDA.jl], [AMDGPU.jl], [Metal.jl] and on [Base.Threads] for high-performance computations on GPUs and CPUs, respectively. It is seamlessly interoperable with [ImplicitGlobalGrid.jl], which renders the distributed parallelization of stencil-based GPU and CPU applications on a regular staggered grid almost trivial and enables close to ideal weak scaling of real-world applications on thousands of GPUs \[[1][JuliaCon20a], [2][JuliaCon20b], [3][JuliaCon19], [4][PASC19]\]. Moreover, ParallelStencil enables hiding communication behind computation with a simple macro call and without any particular restrictions on the package used for communication. ParallelStencil has been designed in conjunction with [ImplicitGlobalGrid.jl] for simplest possible usage by domain-scientists, rendering fast and interactive development of massively scalable high performance multi-GPU applications readily accessible to them. Furthermore, we have developed a self-contained approach for "Solving Nonlinear Multi-Physics on GPU Supercomputers with Julia" relying on ParallelStencil and [ImplicitGlobalGrid.jl] \[[1][JuliaCon20a]\]. ParallelStencil's feature to hide communication behind computation was showcased when a close to ideal weak scaling was demonstrated for a 3-D poro-hydro-mechanical real-world application on up to 1024 GPUs on the Piz Daint Supercomputer \[[1][JuliaCon20a]\]:
![Parallel efficiency of ParallelStencil with CUDA C backend](docs/images/par_eff_c_julia2.png)
@@ -33,7 +33,7 @@ Beyond traditional high-performance computing, ParallelStencil supports automati
* [References](#references)
## Parallelization and optimization with one macro call
-A simple call to `@parallel` is enough to parallelize and optimize a function and to launch it. The package used underneath for parallelization is defined in a call to `@init_parallel_stencil` beforehand. Supported are [CUDA.jl] and [AMDGPU.jl] for running on GPU and [Base.Threads] for CPU. The following example outlines how to run parallel computations on a GPU using the native kernel programming capabilities of [CUDA.jl] underneath (omitted lines are represented with `#(...)`, omitted arguments with `...`):
+A simple call to `@parallel` is enough to parallelize and optimize a function and to launch it. The package used underneath for parallelization is defined in a call to `@init_parallel_stencil` beforehand. Supported are [CUDA.jl], [AMDGPU.jl] and [Metal.jl] for running on GPU and [Base.Threads] for CPU. The following example outlines how to run parallel computations on a GPU using the native kernel programming capabilities of [CUDA.jl] underneath (omitted lines are represented with `#(...)`, omitted arguments with `...`):
```julia
#(...)
@init_parallel_stencil(CUDA,...)
@@ -554,6 +554,7 @@ Please open an issue to discuss your idea for a contribution beforehand. Further
[CellArrays.jl]: https://github.com/omlins/CellArrays.jl
[CUDA.jl]: https://github.com/JuliaGPU/CUDA.jl
[AMDGPU.jl]: https://github.com/JuliaGPU/AMDGPU.jl
+[Metal.jl]: https://github.com/JuliaGPU/Metal.jl
[Enzyme.jl]: https://github.com/EnzymeAD/Enzyme.jl
[MacroTools.jl]: https://github.com/FluxML/MacroTools.jl
[StaticArrays.jl]: https://github.com/JuliaArrays/StaticArrays.jl
diff --git a/ext/ParallelStencil_MetalExt.jl b/ext/ParallelStencil_MetalExt.jl
new file mode 100644
index 00000000..254aac1e
--- /dev/null
+++ b/ext/ParallelStencil_MetalExt.jl
@@ -0,0 +1,4 @@
+module ParallelStencil_MetalExt
+ include(joinpath(@__DIR__, "..", "src", "ParallelKernel", "MetalExt", "shared.jl"))
+ include(joinpath(@__DIR__, "..", "src", "ParallelKernel", "MetalExt", "allocators.jl"))
+end
\ No newline at end of file
diff --git a/src/FiniteDifferences.jl b/src/FiniteDifferences.jl
index a5266c98..82e83072 100644
--- a/src/FiniteDifferences.jl
+++ b/src/FiniteDifferences.jl
@@ -55,7 +55,7 @@ macro d2(A) @expandargs(A); esc(:( ($A[$ixi+1] - $A[$ixi]) - ($A[$ixi] -
macro all(A) @expandargs(A); esc(:( $A[$ix ] )) end
macro inn(A) @expandargs(A); esc(:( $A[$ixi ] )) end
macro av(A) @expandargs(A); esc(:(($A[$ix] + $A[$ix+1] )*0.5 )) end
-macro harm(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix] + 1.0/$A[$ix+1])*2.0 )) end
+macro harm(A) @expandargs(A); esc(:( inv(inv($A[$ix]) + inv($A[$ix+1]))*2.0 )) end
macro maxloc(A) @expandargs(A); esc(:( max( max($A[$ixi-1], $A[$ixi+1]), $A[$ixi] ) )) end
macro minloc(A) @expandargs(A); esc(:( min( min($A[$ixi-1], $A[$ixi+1]), $A[$ixi] ) )) end
@@ -172,11 +172,11 @@ macro av_xa(A) @expandargs(A); esc(:(($A[$ix ,$iy ] + $A[$ix+1,$iy ] )*0
macro av_ya(A) @expandargs(A); esc(:(($A[$ix ,$iy ] + $A[$ix ,$iy+1] )*0.5 )) end
macro av_xi(A) @expandargs(A); esc(:(($A[$ix ,$iyi ] + $A[$ix+1,$iyi ] )*0.5 )) end
macro av_yi(A) @expandargs(A); esc(:(($A[$ixi ,$iy ] + $A[$ixi ,$iy+1] )*0.5 )) end
-macro harm(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iy ] + 1.0/$A[$ix+1,$iy ] + 1.0/$A[$ix,$iy+1] + 1.0/$A[$ix+1,$iy+1])*4.0 )) end
-macro harm_xa(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iy ] + 1.0/$A[$ix+1,$iy ] )*2.0 )) end
-macro harm_ya(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iy ] + 1.0/$A[$ix ,$iy+1] )*2.0 )) end
-macro harm_xi(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iyi ] + 1.0/$A[$ix+1,$iyi ] )*2.0 )) end
-macro harm_yi(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ixi ,$iy ] + 1.0/$A[$ixi ,$iy+1] )*2.0 )) end
+macro harm(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iy ]) + inv($A[$ix+1,$iy ]) + inv($A[$ix,$iy+1]) + inv($A[$ix+1,$iy+1]))*4.0 )) end
+macro harm_xa(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iy ]) + inv($A[$ix+1,$iy ]))*2.0 )) end
+macro harm_ya(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iy ]) + inv($A[$ix ,$iy+1]))*2.0 )) end
+macro harm_xi(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iyi ]) + inv($A[$ix+1,$iyi ]))*2.0 )) end
+macro harm_yi(A) @expandargs(A); esc(:( inv(inv($A[$ixi ,$iy ]) + inv($A[$ixi ,$iy+1]))*2.0 )) end
macro maxloc(A) @expandargs(A); esc(:( max( max( max($A[$ixi-1,$iyi ], $A[$ixi+1,$iyi ]) , $A[$ixi ,$iyi ] ),
max($A[$ixi ,$iyi-1], $A[$ixi ,$iyi+1]) ) )) end
macro minloc(A) @expandargs(A); esc(:( min( min( min($A[$ixi-1,$iyi ], $A[$ixi+1,$iyi ]) , $A[$ixi ,$iyi ] ),
@@ -361,28 +361,28 @@ macro av_xzi(A) @expandargs(A); esc(:(($A[$ix ,$iyi ,$iz ] + $A[$ix+1,$iyi
$A[$ix ,$iyi ,$iz+1] + $A[$ix+1,$iyi ,$iz+1] )*0.25 )) end
macro av_yzi(A) @expandargs(A); esc(:(($A[$ixi ,$iy ,$iz ] + $A[$ixi ,$iy+1,$iz ] +
$A[$ixi ,$iy ,$iz+1] + $A[$ixi ,$iy+1,$iz+1] )*0.25 )) end
-macro harm(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iy ,$iz ] + 1.0/$A[$ix+1,$iy ,$iz ] +
- 1.0/$A[$ix+1,$iy+1,$iz ] + 1.0/$A[$ix+1,$iy+1,$iz+1] +
- 1.0/$A[$ix ,$iy+1,$iz+1] + 1.0/$A[$ix ,$iy ,$iz+1] +
- 1.0/$A[$ix+1,$iy ,$iz+1] + 1.0/$A[$ix ,$iy+1,$iz ] )*8.0)) end
-macro harm_xa(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iy ,$iz ] + 1.0/$A[$ix+1,$iy ,$iz ] )*2.0 )) end
-macro harm_ya(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iy ,$iz ] + 1.0/$A[$ix ,$iy+1,$iz ] )*2.0 )) end
-macro harm_za(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iy ,$iz ] + 1.0/$A[$ix ,$iy ,$iz+1] )*2.0 )) end
-macro harm_xi(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iyi ,$izi ] + 1.0/$A[$ix+1,$iyi ,$izi ] )*2.0 )) end
-macro harm_yi(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ixi ,$iy ,$izi ] + 1.0/$A[$ixi ,$iy+1,$izi ] )*2.0 )) end
-macro harm_zi(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ixi ,$iyi ,$iz ] + 1.0/$A[$ixi ,$iyi ,$iz+1] )*2.0 )) end
-macro harm_xya(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iy ,$iz ] + 1.0/$A[$ix+1,$iy ,$iz ] +
- 1.0/$A[$ix ,$iy+1,$iz ] + 1.0/$A[$ix+1,$iy+1,$iz ] )*4.0 )) end
-macro harm_xza(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iy ,$iz ] + 1.0/$A[$ix+1,$iy ,$iz ] +
- 1.0/$A[$ix ,$iy ,$iz+1] + 1.0/$A[$ix+1,$iy ,$iz+1] )*4.0 )) end
-macro harm_yza(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iy ,$iz ] + 1.0/$A[$ix ,$iy+1,$iz ] +
- 1.0/$A[$ix ,$iy ,$iz+1] + 1.0/$A[$ix ,$iy+1,$iz+1] )*4.0 )) end
-macro harm_xyi(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iy ,$izi ] + 1.0/$A[$ix+1,$iy ,$izi ] +
- 1.0/$A[$ix ,$iy+1,$izi ] + 1.0/$A[$ix+1,$iy+1,$izi ] )*4.0 )) end
-macro harm_xzi(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ix ,$iyi ,$iz ] + 1.0/$A[$ix+1,$iyi ,$iz ] +
- 1.0/$A[$ix ,$iyi ,$iz+1] + 1.0/$A[$ix+1,$iyi ,$iz+1] )*4.0 )) end
-macro harm_yzi(A) @expandargs(A); esc(:(1.0/(1.0/$A[$ixi ,$iy ,$iz ] + 1.0/$A[$ixi ,$iy+1,$iz ] +
- 1.0/$A[$ixi ,$iy ,$iz+1] + 1.0/$A[$ixi ,$iy+1,$iz+1] )*4.0 )) end
+macro harm(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iy ,$iz ]) + inv($A[$ix+1,$iy ,$iz ]) +
+ inv($A[$ix+1,$iy+1,$iz ]) + inv($A[$ix+1,$iy+1,$iz+1]) +
+ inv($A[$ix ,$iy+1,$iz+1]) + inv($A[$ix ,$iy ,$iz+1]) +
+ inv($A[$ix+1,$iy ,$iz+1]) + inv($A[$ix ,$iy+1,$iz ]) )*8.0 )) end
+macro harm_xa(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iy ,$iz ]) + inv($A[$ix+1,$iy ,$iz ]) )*2.0 )) end
+macro harm_ya(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iy ,$iz ]) + inv($A[$ix ,$iy+1,$iz ]) )*2.0 )) end
+macro harm_za(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iy ,$iz ]) + inv($A[$ix ,$iy ,$iz+1]) )*2.0 )) end
+macro harm_xi(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iyi ,$izi ]) + inv($A[$ix+1,$iyi ,$izi ]) )*2.0 )) end
+macro harm_yi(A) @expandargs(A); esc(:( inv(inv($A[$ixi ,$iy ,$izi ]) + inv($A[$ixi ,$iy+1,$izi ]) )*2.0 )) end
+macro harm_zi(A) @expandargs(A); esc(:( inv(inv($A[$ixi ,$iyi ,$iz ]) + inv($A[$ixi ,$iyi ,$iz+1]) )*2.0 )) end
+macro harm_xya(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iy ,$iz ]) + inv($A[$ix+1,$iy ,$iz ]) +
+ inv($A[$ix ,$iy+1,$iz ]) + inv($A[$ix+1,$iy+1,$iz ]) )*4.0 )) end
+macro harm_xza(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iy ,$iz ]) + inv($A[$ix+1,$iy ,$iz ]) +
+ inv($A[$ix ,$iy ,$iz+1]) + inv($A[$ix+1,$iy ,$iz+1]) )*4.0 )) end
+macro harm_yza(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iy ,$iz ]) + inv($A[$ix ,$iy+1,$iz ]) +
+ inv($A[$ix ,$iy ,$iz+1]) + inv($A[$ix ,$iy+1,$iz+1]) )*4.0 )) end
+macro harm_xyi(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iy ,$izi ]) + inv($A[$ix+1,$iy ,$izi ]) +
+ inv($A[$ix ,$iy+1,$izi ]) + inv($A[$ix+1,$iy+1,$izi ]) )*4.0 )) end
+macro harm_xzi(A) @expandargs(A); esc(:( inv(inv($A[$ix ,$iyi ,$iz ]) + inv($A[$ix+1,$iyi ,$iz ]) +
+ inv($A[$ix ,$iyi ,$iz+1]) + inv($A[$ix+1,$iyi ,$iz+1]) )*4.0 )) end
+macro harm_yzi(A) @expandargs(A); esc(:( inv(inv($A[$ixi ,$iy ,$iz ]) + inv($A[$ixi ,$iy+1,$iz ]) +
+ inv($A[$ixi ,$iy ,$iz+1]) + inv($A[$ixi ,$iy+1,$iz+1]) )*4.0 )) end
macro maxloc(A) @expandargs(A); esc(:( max( max( max( max($A[$ixi-1,$iyi ,$izi ], $A[$ixi+1,$iyi ,$izi ]) , $A[$ixi ,$iyi ,$izi ] ),
max($A[$ixi ,$iyi-1,$izi ], $A[$ixi ,$iyi+1,$izi ]) ),
max($A[$ixi ,$iyi ,$izi-1], $A[$ixi ,$iyi ,$izi+1]) ) )) end
diff --git a/src/ParallelKernel/Data.jl b/src/ParallelKernel/Data.jl
index ed8c357c..9ac81b58 100644
--- a/src/ParallelKernel/Data.jl
+++ b/src/ParallelKernel/Data.jl
@@ -31,12 +31,12 @@ The type of indices used in parallel kernels.
--------------------------------------------------------------------------------
Data.Array{ndims}
-Expands to `Data.Array{numbertype, ndims}`, where `numbertype` is the datatype selected with [`@init_parallel_kernel`](@ref) and the datatype `Data.Array` is chosen to be compatible with the package for parallelization selected with [`@init_parallel_kernel`](@ref) (Array for Threads or Polyester, CUDA.CuArray or CUDA.CuDeviceArray for CUDA and AMDGPU.ROCArray or AMDGPU.ROCDeviceArray for AMDGPU; [`@parallel`](@ref) and [`@parallel_indices`](@ref) convert CUDA.CuArray and AMDGPU.ROCArray automatically to CUDA.CuDeviceArray and AMDGPU.ROCDeviceArray in kernels when required).
+Expands to `Data.Array{numbertype, ndims}`, where `numbertype` is the datatype selected with [`@init_parallel_kernel`](@ref) and the datatype `Data.Array` is chosen to be compatible with the package for parallelization selected with [`@init_parallel_kernel`](@ref) (Array for Threads or Polyester, CUDA.CuArray or CUDA.CuDeviceArray for CUDA, AMDGPU.ROCArray or AMDGPU.ROCDeviceArray for AMDGPU and Metal.MtlArray or Metal.MtlDeviceArray for Metal; [`@parallel`](@ref) and [`@parallel_indices`](@ref) convert CUDA.CuArray, AMDGPU.ROCArray and Metal.MtlArray automatically to CUDA.CuDeviceArray, AMDGPU.ROCDeviceArray and Metal.MtlDeviceArray in kernels when required).
--------------------------------------------------------------------------------
Data.CellArray{ndims}
-Expands to `Data.CellArray{numbertype, ndims}`, where `numbertype` is the datatype selected with [`@init_parallel_kernel`](@ref) and the datatype `Data.CellArray` is chosen to be compatible with the package for parallelization selected with [`@init_parallel_kernel`](@ref) (CPUCellArray for Threads or Polyester, CuCellArray or CuDeviceCellArray for CUDA and ROCCellArray or ROCDeviceCellArray for AMDGPU; [`@parallel`](@ref) and [`@parallel_indices`](@ref) convert CellArray automatically to DeviceCellArray when required).
+Expands to `Data.CellArray{numbertype, ndims}`, where `numbertype` is the datatype selected with [`@init_parallel_kernel`](@ref) and the datatype `Data.CellArray` is chosen to be compatible with the package for parallelization selected with [`@init_parallel_kernel`](@ref) (CPUCellArray for Threads or Polyester, CuCellArray or CuDeviceCellArray for CUDA, ROCCellArray or ROCDeviceCellArray for AMDGPU and MtlCellArray or MtlDeviceCellArray for Metal; [`@parallel`](@ref) and [`@parallel_indices`](@ref) convert CellArray automatically to DeviceCellArray when required).
--------------------------------------------------------------------------------
Data.Cell{S}
@@ -143,12 +143,12 @@ The type of indices used in parallel kernels.
--------------------------------------------------------------------------------
Data.Array{numbertype, ndims}
-The datatype `Data.Array` is automatically chosen to be compatible with the package for parallelization selected with [`@init_parallel_kernel`](@ref) (Array for Threads or Polyester, CUDA.CuArray or CUDA.CuDeviceArray for CUDA and AMDGPU.ROCArray or AMDGPU.ROCDeviceArray for AMDGPU; [`@parallel`](@ref) and [`@parallel_indices`](@ref) convert CUDA.CuArray and AMDGPU.ROCArray automatically to CUDA.CuDeviceArray and AMDGPU.ROCDeviceArray in kernels when required).
+The datatype `Data.Array` is automatically chosen to be compatible with the package for parallelization selected with [`@init_parallel_kernel`](@ref) (Array for Threads or Polyester, CUDA.CuArray or CUDA.CuDeviceArray for CUDA, AMDGPU.ROCArray or AMDGPU.ROCDeviceArray for AMDGPU and Metal.MtlArray or Metal.MtlDeviceArray for Metal; [`@parallel`](@ref) and [`@parallel_indices`](@ref) convert CUDA.CuArray, AMDGPU.ROCArray and Metal.MtlArray automatically to CUDA.CuDeviceArray, AMDGPU.ROCDeviceArray and Metal.MtlDeviceArray in kernels when required).
--------------------------------------------------------------------------------
Data.CellArray{numbertype, ndims}
-The datatype `Data.CellArray` is automatically chosen to be compatible with the package for parallelization selected with [`@init_parallel_kernel`](@ref) (CPUCellArray for Threads or Polyester, CuCellArray or CuDeviceCellArray for CUDA and ROCCellArray or ROCDeviceCellArray for AMDGPU; [`@parallel`](@ref) and [`@parallel_indices`](@ref) convert CellArray automatically to DeviceCellArray in kernels when required).
+The datatype `Data.CellArray` is automatically chosen to be compatible with the package for parallelization selected with [`@init_parallel_kernel`](@ref) (CPUCellArray for Threads or Polyester, CuCellArray or CuDeviceCellArray for CUDA, ROCCellArray or ROCDeviceCellArray for AMDGPU and MtlCellArray or MetalDeviceCellArray for Metal; [`@parallel`](@ref) and [`@parallel_indices`](@ref) convert CellArray automatically to DeviceCellArray in kernels when required).
--------------------------------------------------------------------------------
Data.Cell{numbertype, S}
@@ -422,6 +422,86 @@ function TData_Device_amdgpu()
end)
end
+# Metal
+
+function Data_metal(numbertype::DataType, indextype::DataType)
+ Data_module = if (numbertype == NUMBERTYPE_NONE)
+ :(baremodule $MODULENAME_DATA # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
+ import Base, Metal, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
+ const MtlCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,Metal.MtlArray{T_elem,CellArrays._N}}
+ const Index = $indextype
+ const Array{T, N} = Metal.MtlArray{T, N}
+ const Cell{T, S} = Union{StaticArrays.SArray{S, T}, StaticArrays.FieldArray{S, T}}
+ const CellArray{T_elem, N, B} = MtlCellArray{<:Cell{T_elem},N,B,T_elem}
+ $(Data_xpu_exprs(numbertype))
+ $(Data_Device_metal(numbertype, indextype))
+ $(Data_Fields(numbertype, indextype))
+ end)
+ else
+ :(baremodule $MODULENAME_DATA
+ import Base, Metal, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
+ const MtlCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,Metal.MtlArray{T_elem,CellArrays._N}}
+ const Index = $indextype
+ const Number = $numbertype
+ const Array{N} = Metal.MtlArray{$numbertype, N}
+ const Cell{S} = Union{StaticArrays.SArray{S, $numbertype}, StaticArrays.FieldArray{S, $numbertype}}
+ const CellArray{N, B} = MtlCellArray{<:Cell,N,B,$numbertype}
+ $(Data_xpu_exprs(numbertype))
+ $(Data_Device_metal(numbertype, indextype))
+ $(Data_Fields(numbertype, indextype))
+ end)
+ end
+ return prewalk(rmlines, flatten(Data_module))
+end
+
+function TData_metal()
+ TData_module = :(
+ baremodule $MODULENAME_TDATA
+ import Base, Metal, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
+ const MtlCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,Metal.MtlArray{T_elem,CellArrays._N}}
+ const Array{T, N} = Metal.MtlArray{T, N}
+ const Cell{T, S} = Union{StaticArrays.SArray{S, T}, StaticArrays.FieldArray{S, T}}
+ const CellArray{T_elem, N, B} = MtlCellArray{<:Cell{T_elem},N,B,T_elem}
+ $(TData_xpu_exprs())
+ $(TData_Device_metal())
+ $(TData_Fields())
+ end
+ )
+ return prewalk(rmlines, flatten(TData_module))
+end
+
+function Data_Device_metal(numbertype::DataType, indextype::DataType)
+ Device_module = if (numbertype == NUMBERTYPE_NONE)
+ :(baremodule $MODULENAME_DEVICE
+ import Base, Metal, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
+ const Index = $indextype
+ const Array{T, N} = Metal.MtlDeviceArray{T, N}
+ const Cell{T, S} = Union{StaticArrays.SArray{S, T}, StaticArrays.FieldArray{S, T}}
+ const CellArray{T_elem, N, B} = CellArrays.CellArray{<:Cell{T_elem},N,B,<:Metal.MtlDeviceArray{T_elem,CellArrays._N}}
+ $(Data_xpu_exprs(numbertype))
+ end)
+ else
+ :(baremodule $MODULENAME_DEVICE
+ import Base, Metal, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
+ const Index = $indextype
+ const Array{N} = Metal.MtlDeviceArray{$numbertype, N}
+ const Cell{S} = Union{StaticArrays.SArray{S, $numbertype}, StaticArrays.FieldArray{S, $numbertype}}
+ const CellArray{N, B} = CellArrays.CellArray{<:Cell,N,B,<:Metal.MtlDeviceArray{$numbertype,CellArrays._N}}
+ $(Data_xpu_exprs(numbertype))
+ end)
+ end
+ return Device_module
+end
+
+function TData_Device_metal()
+ :(baremodule $MODULENAME_DEVICE
+ import Base, Metal, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
+ const Array{T, N} = Metal.MtlDeviceArray{T, N}
+ const Cell{T, S} = Union{StaticArrays.SArray{S, T}, StaticArrays.FieldArray{S, T}}
+ const CellArray{T_elem, N, B} = CellArrays.CellArray{<:Cell{T_elem},N,B,<:Metal.MtlDeviceArray{T_elem,CellArrays._N}}
+ $(TData_xpu_exprs())
+ end)
+end
# CPU
diff --git a/src/ParallelKernel/MetalExt/allocators.jl b/src/ParallelKernel/MetalExt/allocators.jl
new file mode 100644
index 00000000..f2251f51
--- /dev/null
+++ b/src/ParallelKernel/MetalExt/allocators.jl
@@ -0,0 +1,29 @@
+## RUNTIME ALLOCATOR FUNCTIONS
+
+ParallelStencil.ParallelKernel.zeros_metal(::Type{T}, blocklength, args...) where {T<:Number} = (check_datatype_metal(T); Metal.zeros(T, args...)) # (blocklength is ignored if neither celldims nor celltype is set)
+ParallelStencil.ParallelKernel.ones_metal(::Type{T}, blocklength, args...) where {T<:Number} = (check_datatype_metal(T); Metal.ones(T, args...))
+ParallelStencil.ParallelKernel.rand_metal(::Type{T}, blocklength, args...) where {T<:Union{Number,Enum}} = MtlArray(rand_cpu(T, blocklength, args...))
+ParallelStencil.ParallelKernel.falses_metal(::Type{T}, blocklength, args...) where {T<:Bool} = Metal.falses(args...)
+ParallelStencil.ParallelKernel.trues_metal(::Type{T}, blocklength, args...) where {T<:Bool} = Metal.trues(args...)
+ParallelStencil.ParallelKernel.fill_metal(::Type{T}, blocklength, args...) where {T<:Union{Number,Enum}} = MtlArray(fill_cpu(T, blocklength, args...))
+
+ParallelStencil.ParallelKernel.zeros_metal(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = (check_datatype_metal(T); fill_metal(T, blocklength, 0, args...))
+ParallelStencil.ParallelKernel.ones_metal(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = (check_datatype_metal(T); fill_metal(T, blocklength, 1, args...))
+ParallelStencil.ParallelKernel.rand_metal(::Type{T}, ::Val{B}, dims) where {T<:Union{SArray,FieldArray}, B} = (check_datatype_metal(T, Bool, Enum); blocklen = (B == 0) ? prod(dims) : B; CellArray{T,length(dims),B, Metal.MtlArray{eltype(T),3}}(Metal.rand(eltype(T), blocklen, prod(size(T)), ceil(Int,prod(dims)/(blocklen))), dims))
+ParallelStencil.ParallelKernel.rand_metal(::Type{T}, blocklength, dims...) where {T<:Union{SArray,FieldArray}} = rand_metal(T, blocklength, dims)
+ParallelStencil.ParallelKernel.falses_metal(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = fill_metal(T, blocklength, false, args...)
+ParallelStencil.ParallelKernel.trues_metal(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = fill_metal(T, blocklength, true, args...)
+
+function ParallelStencil.ParallelKernel.fill_metal(::Type{T}, ::Val{B}, x, args...) where {T <: Union{SArray,FieldArray}, B}
+ if (!(eltype(x) <: Number) || (eltype(x) == Bool)) && (eltype(x) != eltype(T)) @ArgumentError("fill: the (element) type of argument 'x' is not a normal number type ($(eltype(x))), but does not match the obtained (default) 'eltype' ($(eltype(T))); automatic conversion to $(eltype(T)) is therefore not attempted. Set the keyword argument 'eltype' accordingly to the element type of 'x' or pass an 'x' of a different (element) type.") end
+ check_datatype_metal(T, Bool, Enum)
+ if (length(x) == 1) cell = convert(T, fill(convert(eltype(T), x), size(T)))
+ elseif (length(x) == length(T)) cell = convert(T, x)
+ else @ArgumentError("fill: argument 'x' contains the wrong number of elements ($(length(x))). It must be a scalar or contain the number of elements defined by 'celldims'.")
+ end
+ return CellArrays.fill!(MtlCellArray{T,B}(undef, args...), cell)
+end
+
+ParallelStencil.ParallelKernel.fill_metal!(A, x) = Metal.fill!(A, construct_cell(A, x))
+
+check_datatype_metal(args...) = check_datatype(args..., INT_METAL)
\ No newline at end of file
diff --git a/src/ParallelKernel/MetalExt/defaults.jl b/src/ParallelKernel/MetalExt/defaults.jl
new file mode 100644
index 00000000..abc3e224
--- /dev/null
+++ b/src/ParallelKernel/MetalExt/defaults.jl
@@ -0,0 +1,18 @@
+const ERRMSG_METALEXT_NOT_LOADED = "the Metal extension was not loaded. Make sure to import Metal before ParallelStencil."
+
+# shared.jl
+
+function get_priority_metalstream end
+function get_metalstream end
+
+# allocators
+
+zeros_metal(arg...) = @NotLoadedError(ERRMSG_METALEXT_NOT_LOADED)
+ones_metal(arg...) = @NotLoadedError(ERRMSG_METALEXT_NOT_LOADED)
+rand_metal(arg...) = @NotLoadedError(ERRMSG_METALEXT_NOT_LOADED)
+falses_metal(arg...) = @NotLoadedError(ERRMSG_METALEXT_NOT_LOADED)
+trues_metal(arg...) = @NotLoadedError(ERRMSG_METALEXT_NOT_LOADED)
+fill_metal(arg...) = @NotLoadedError(ERRMSG_METALEXT_NOT_LOADED)
+fill_metal!(arg...) = @NotLoadedError(ERRMSG_METALEXT_NOT_LOADED)
+
+
diff --git a/src/ParallelKernel/MetalExt/shared.jl b/src/ParallelKernel/MetalExt/shared.jl
new file mode 100644
index 00000000..60b71499
--- /dev/null
+++ b/src/ParallelKernel/MetalExt/shared.jl
@@ -0,0 +1,30 @@
+import ParallelStencil
+import ParallelStencil.ParallelKernel: INT_METAL, rand_cpu, fill_cpu, construct_cell, check_datatype, rand_metal, fill_metal
+using ParallelStencil.ParallelKernel.Exceptions
+using Metal, CellArrays, StaticArrays
+import Metal.MTL
+
+@define_MtlCellArray
+
+## FUNCTIONS TO CHECK EXTENSIONS SUPPORT
+ParallelStencil.ParallelKernel.is_loaded(::Val{:ParallelStencil_MetalExt}) = true
+
+## FUNCTIONS TO GET CREATE AND MANAGE METAL QUEUES
+ParallelStencil.ParallelKernel.get_priority_metalstream(arg...) = get_priority_metalstream(arg...)
+ParallelStencil.ParallelKernel.get_metalstream(arg...) = get_metalstream(arg...)
+
+let
+ global get_priority_metalstream, get_metalstream
+ priority_metalqueues = Array{MTL.MTLCommandQueue}(undef, 0)
+ metalqueues = Array{MTL.MTLCommandQueue}(undef, 0)
+
+ function get_priority_metalstream(id::Integer)
+ while (id > length(priority_metalqueues)) push!(priority_metalqueues, MTL.MTLCommandQueue(Metal.device())) end # No priority setting available in Metal queues.
+ return priority_metalqueues[id]
+ end
+
+ function get_metalstream(id::Integer)
+ while (id > length(metalqueues)) push!(metalqueues, MTL.MTLCommandQueue(Metal.device())) end
+ return metalqueues[id]
+ end
+end
\ No newline at end of file
diff --git a/src/ParallelKernel/ParallelKernel.jl b/src/ParallelKernel/ParallelKernel.jl
index e901acbd..740e1b9e 100644
--- a/src/ParallelKernel/ParallelKernel.jl
+++ b/src/ParallelKernel/ParallelKernel.jl
@@ -54,6 +54,7 @@ include(joinpath("EnzymeExt", "AD.jl"))
## Alphabetical include of defaults for extensions
include(joinpath("AMDGPUExt", "defaults.jl"))
include(joinpath("CUDAExt", "defaults.jl"))
+include(joinpath("MetalExt", "defaults.jl"))
## Include of constant parameters, types and syntax sugar shared in ParallelKernel module only
include("shared.jl")
diff --git a/src/ParallelKernel/allocators.jl b/src/ParallelKernel/allocators.jl
index 90b8e240..ca47db03 100644
--- a/src/ParallelKernel/allocators.jl
+++ b/src/ParallelKernel/allocators.jl
@@ -3,7 +3,7 @@ const ZEROS_DOC = """
@zeros(args...)
@zeros(args..., )
-Call `zeros(eltype, args...)`, where `eltype` is by default the `numbertype` selected with [`@init_parallel_kernel`](@ref) and the function `zeros` is chosen to be compatible with the package for parallelization selected with [`@init_parallel_kernel`](@ref) (zeros for Threads or Polyester, CUDA.zeros for CUDA and AMDGPU.zeros for AMDGPU).
+Call `zeros(eltype, args...)`, where `eltype` is by default the `numbertype` selected with [`@init_parallel_kernel`](@ref) and the function `zeros` is chosen to be compatible with the package for parallelization selected with [`@init_parallel_kernel`](@ref) (zeros for Threads or Polyester, CUDA.zeros for CUDA, AMDGPU.zeros for AMDGPU and Metal.zeros for Metal).
!!! note "Advanced"
The `eltype` can be explicitly passed as keyword argument in order to be used instead of the default `numbertype` chosen with [`@init_parallel_kernel`](@ref). If no default `numbertype` was chosen [`@init_parallel_kernel`](@ref), then the keyword argument `eltype` is mandatory. This needs to be used with care to ensure that no datatype conversions occur in performance critical computations.
@@ -31,7 +31,7 @@ const ONES_DOC = """
@ones(args...)
@ones(args..., )
-Call `ones(eltype, args...)`, where `eltype` is by default the `numbertype` selected with [`@init_parallel_kernel`](@ref) and the function `ones` is chosen to be compatible with the package for parallelization selected with [`@init_parallel_kernel`](@ref) (ones for Threads or Polyester, CUDA.ones for CUDA and AMDGPU.ones for AMDGPU).
+Call `ones(eltype, args...)`, where `eltype` is by default the `numbertype` selected with [`@init_parallel_kernel`](@ref) and the function `ones` is chosen to be compatible with the package for parallelization selected with [`@init_parallel_kernel`](@ref) (ones for Threads or Polyester, CUDA.ones for CUDA, AMDGPU.ones for AMDGPU and Metal.ones for Metal).
!!! note "Advanced"
The `eltype` can be explicitly passed as keyword argument in order to be used instead of the default `numbertype` chosen with [`@init_parallel_kernel`](@ref). If no default `numbertype` was chosen [`@init_parallel_kernel`](@ref), then the keyword argument `eltype` is mandatory. This needs to be used with care to ensure that no datatype conversions occur in performance critical computations.
@@ -240,6 +240,13 @@ macro falses_amdgpu(args...) check_initialized(__module__); esc(_falses(__mod
macro trues_amdgpu(args...) check_initialized(__module__); esc(_trues(__module__, args...; package=PKG_AMDGPU)); end
macro fill_amdgpu(args...) check_initialized(__module__); esc(_fill(__module__, args...; package=PKG_AMDGPU)); end
macro fill!_amdgpu(args...) check_initialized(__module__); esc(_fill!(__module__, args...; package=PKG_AMDGPU)); end
+macro zeros_metal(args...) check_initialized(__module__); esc(_zeros(__module__, args...; package=PKG_METAL)); end
+macro ones_metal(args...) check_initialized(__module__); esc(_ones(__module__, args...; package=PKG_METAL)); end
+macro rand_metal(args...) check_initialized(__module__); esc(_rand(__module__, args...; package=PKG_METAL)); end
+macro falses_metal(args...) check_initialized(__module__); esc(_falses(__module__, args...; package=PKG_METAL)); end
+macro trues_metal(args...) check_initialized(__module__); esc(_trues(__module__, args...; package=PKG_METAL)); end
+macro fill_metal(args...) check_initialized(__module__); esc(_fill(__module__, args...; package=PKG_METAL)); end
+macro fill!_metal(args...) check_initialized(__module__); esc(_fill!(__module__, args...; package=PKG_METAL)); end
macro zeros_threads(args...) check_initialized(__module__); esc(_zeros(__module__, args...; package=PKG_THREADS)); end
macro ones_threads(args...) check_initialized(__module__); esc(_ones(__module__, args...; package=PKG_THREADS)); end
macro rand_threads(args...) check_initialized(__module__); esc(_rand(__module__, args...; package=PKG_THREADS)); end
@@ -274,6 +281,7 @@ function _zeros(caller::Module, args...; eltype=nothing, celldims=nothing, cellt
blocklength = determine_blocklength(blocklength, package)
if (package == PKG_CUDA) return :(ParallelStencil.ParallelKernel.zeros_cuda($celltype, $blocklength, $(args...)))
elseif (package == PKG_AMDGPU) return :(ParallelStencil.ParallelKernel.zeros_amdgpu($celltype, $blocklength, $(args...)))
+ elseif (package == PKG_METAL) return :(ParallelStencil.ParallelKernel.zeros_metal($celltype, $blocklength, $(args...)))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.zeros_cpu($celltype, $blocklength, $(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -284,6 +292,7 @@ function _ones(caller::Module, args...; eltype=nothing, celldims=nothing, cellty
blocklength = determine_blocklength(blocklength, package)
if (package == PKG_CUDA) return :(ParallelStencil.ParallelKernel.ones_cuda($celltype, $blocklength, $(args...)))
elseif (package == PKG_AMDGPU) return :(ParallelStencil.ParallelKernel.ones_amdgpu($celltype, $blocklength, $(args...)))
+ elseif (package == PKG_METAL) return :(ParallelStencil.ParallelKernel.ones_metal($celltype, $blocklength, $(args...)))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.ones_cpu($celltype, $blocklength, $(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -294,6 +303,7 @@ function _rand(caller::Module, args...; eltype=nothing, celldims=nothing, cellty
blocklength = determine_blocklength(blocklength, package)
if (package == PKG_CUDA) return :(ParallelStencil.ParallelKernel.rand_cuda($celltype, $blocklength, $(args...)))
elseif (package == PKG_AMDGPU) return :(ParallelStencil.ParallelKernel.rand_amdgpu($celltype, $blocklength, $(args...)))
+ elseif (package == PKG_METAL) return :(ParallelStencil.ParallelKernel.rand_metal($celltype, $blocklength, $(args...)))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.rand_cpu($celltype, $blocklength, $(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -304,6 +314,7 @@ function _falses(caller::Module, args...; celldims=nothing, blocklength=nothing,
blocklength = determine_blocklength(blocklength, package)
if (package == PKG_CUDA) return :(ParallelStencil.ParallelKernel.falses_cuda($celltype, $blocklength, $(args...)))
elseif (package == PKG_AMDGPU) return :(ParallelStencil.ParallelKernel.falses_amdgpu($celltype, $blocklength, $(args...)))
+ elseif (package == PKG_METAL) return :(ParallelStencil.ParallelKernel.falses_metal($celltype, $blocklength, $(args...)))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.falses_cpu($celltype, $blocklength, $(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -314,6 +325,7 @@ function _trues(caller::Module, args...; celldims=nothing, blocklength=nothing,
blocklength = determine_blocklength(blocklength, package)
if (package == PKG_CUDA) return :(ParallelStencil.ParallelKernel.trues_cuda($celltype, $blocklength, $(args...)))
elseif (package == PKG_AMDGPU) return :(ParallelStencil.ParallelKernel.trues_amdgpu($celltype, $blocklength, $(args...)))
+ elseif (package == PKG_METAL) return :(ParallelStencil.ParallelKernel.trues_metal($celltype, $blocklength, $(args...)))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.trues_cpu($celltype, $blocklength, $(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -324,6 +336,7 @@ function _fill(caller::Module, args...; eltype=nothing, celldims=nothing, cellty
blocklength = determine_blocklength(blocklength, package)
if (package == PKG_CUDA) return :(ParallelStencil.ParallelKernel.fill_cuda($celltype, $blocklength, $(args...)))
elseif (package == PKG_AMDGPU) return :(ParallelStencil.ParallelKernel.fill_amdgpu($celltype, $blocklength, $(args...)))
+ elseif (package == PKG_METAL) return :(ParallelStencil.ParallelKernel.fill_metal($celltype, $blocklength, $(args...)))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.fill_cpu($celltype, $blocklength, $(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -332,6 +345,7 @@ end
function _fill!(caller::Module, args...; package::Symbol=get_package(caller))
if (package == PKG_CUDA) return :(ParallelStencil.ParallelKernel.fill_cuda!($(args...)))
elseif (package == PKG_AMDGPU) return :(ParallelStencil.ParallelKernel.fill_amdgpu!($(args...)))
+ elseif (package == PKG_METAL) return :(ParallelStencil.ParallelKernel.fill_metal!($(args...)))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.fill_cpu!($(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
diff --git a/src/ParallelKernel/hide_communication.jl b/src/ParallelKernel/hide_communication.jl
index 2360cc27..0eb58fc6 100644
--- a/src/ParallelKernel/hide_communication.jl
+++ b/src/ParallelKernel/hide_communication.jl
@@ -121,6 +121,7 @@ end
function get_priority_stream(caller::Module, args::Union{Integer,Symbol,Expr}...; package::Symbol=get_package(caller))
if (package == PKG_CUDA) get_priority_stream_cuda(args...)
elseif (package == PKG_AMDGPU) get_priority_stream_amdgpu(args...)
+ elseif (package == PKG_METAL) get_priority_stream_metal(args...)
else @ArgumentError("unsupported GPU package (obtained: $package).")
end
end
@@ -128,6 +129,7 @@ end
function get_stream(caller::Module, args::Union{Integer,Symbol,Expr}...; package::Symbol=get_package(caller))
if (package == PKG_CUDA) get_stream_cuda(args...)
elseif (package == PKG_AMDGPU) get_stream_amdgpu(args...)
+ elseif (package == PKG_METAL) get_stream_metal(args...)
else @ArgumentError("unsupported GPU package (obtained: $package).")
end
end
@@ -222,8 +224,10 @@ end
get_priority_stream_cuda(id::Union{Integer,Symbol,Expr}) = return :(ParallelStencil.ParallelKernel.get_priority_custream($id))
get_priority_stream_amdgpu(id::Union{Integer,Symbol,Expr}) = return :(ParallelStencil.ParallelKernel.get_priority_rocstream($id))
+get_priority_stream_metal(id::Union{Integer,Symbol,Expr}) = return :(ParallelStencil.ParallelKernel.get_priority_metalstream($id))
get_stream_cuda(id::Union{Integer,Symbol,Expr}) = return :(ParallelStencil.ParallelKernel.get_custream($id))
get_stream_amdgpu(id::Union{Integer,Symbol,Expr}) = return :(ParallelStencil.ParallelKernel.get_rocstream($id))
+get_stream_metal(id::Union{Integer,Symbol,Expr}) = return :(ParallelStencil.ParallelKernel.get_metalstream($id))
## FUNCTIONS TO EXTRACT AND PROCESS COMPUTATION AND BOUNDARY CONDITIONS CALLS / COMMUNICATION CALLS
diff --git a/src/ParallelKernel/init_parallel_kernel.jl b/src/ParallelKernel/init_parallel_kernel.jl
index e91ed867..c1535434 100644
--- a/src/ParallelKernel/init_parallel_kernel.jl
+++ b/src/ParallelKernel/init_parallel_kernel.jl
@@ -4,7 +4,7 @@
Initialize the package ParallelKernel, giving access to its main functionality. Creates a module `Data` in the module where `@init_parallel_kernel` is called from. The module `Data` contains the types as `Data.Number`, `Data.Array` and `Data.CellArray` (type `?Data` *after* calling `@init_parallel_kernel` to see the full description of the module).
# Arguments
-- `package::Module`: the package used for parallelization (CUDA or AMDGPU for GPU, or Threads or Polyester for CPU).
+- `package::Module`: the package used for parallelization (CUDA or AMDGPU or Metal for GPU, or Threads or Polyester for CPU).
- `numbertype::DataType`: the type of numbers used by @zeros, @ones, @rand and @fill and in all array types of module `Data` (e.g. Float32 or Float64). It is contained in `Data.Number` after @init_parallel_kernel.
- `inbounds::Bool=false`: whether to apply `@inbounds` to the kernels by default (overwritable in each kernel definition).
@@ -36,6 +36,11 @@ function init_parallel_kernel(caller::Module, package::Symbol, numbertype::DataT
indextype = INT_AMDGPU
data_module = Data_amdgpu(numbertype, indextype)
tdata_module = TData_amdgpu()
+ elseif package == PKG_METAL
+ if (isinteractive() && !is_installed("Metal")) @NotInstalledError("Metal was selected as package for parallelization, but Metal.jl is not installed. Metal functionality is provided as an extension of $parent_module and Metal.jl needs therefore to be installed independently (type `add Metal` in the julia package manager).") end
+ indextype = INT_METAL
+ data_module = Data_metal(numbertype, indextype)
+ tdata_module = TData_metal()
elseif package == PKG_POLYESTER
if (isinteractive() && !is_installed("Polyester")) @NotInstalledError("Polyester was selected as package for parallelization, but Polyester.jl is not installed. Multi-threading using Polyester is provided as an extension of $parent_module and Polyester.jl needs therefore to be installed independently (type `add Polyester` in the julia package manager).") end
indextype = INT_POLYESTER
diff --git a/src/ParallelKernel/kernel_language.jl b/src/ParallelKernel/kernel_language.jl
index a714a95a..ca3b977f 100644
--- a/src/ParallelKernel/kernel_language.jl
+++ b/src/ParallelKernel/kernel_language.jl
@@ -172,6 +172,7 @@ end
function gridDim(caller::Module, args...; package::Symbol=get_package(caller))
if (package == PKG_CUDA) return :(CUDA.gridDim($(args...)))
elseif (package == PKG_AMDGPU) return :(AMDGPU.gridGroupDim($(args...)))
+ elseif (package == PKG_METAL) return :(Metal.threadgroups_per_grid_3d($(args...)))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.@gridDim_cpu($(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -180,6 +181,7 @@ end
function blockIdx(caller::Module, args...; package::Symbol=get_package(caller)) #NOTE: the CPU implementation relies on the fact that ranges are always of type UnitRange. If this changes, then this function needs to be adapted.
if (package == PKG_CUDA) return :(CUDA.blockIdx($(args...)))
elseif (package == PKG_AMDGPU) return :(AMDGPU.workgroupIdx($(args...)))
+ elseif (package == PKG_METAL) return :(Metal.threadgroup_position_in_grid_3d($(args...)))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.@blockIdx_cpu($(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -188,6 +190,7 @@ end
function blockDim(caller::Module, args...; package::Symbol=get_package(caller)) #NOTE: the CPU implementation follows the model that no threads are grouped into blocks, i.e. that each block contains only 1 thread (with thread ID 1). The parallelization happens only over the blocks.
if (package == PKG_CUDA) return :(CUDA.blockDim($(args...)))
elseif (package == PKG_AMDGPU) return :(AMDGPU.workgroupDim($(args...)))
+ elseif (package == PKG_METAL) return :(Metal.threads_per_threadgroup_3d($(args...)))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.@blockDim_cpu($(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -196,6 +199,7 @@ end
function threadIdx(caller::Module, args...; package::Symbol=get_package(caller)) #NOTE: the CPU implementation follows the model that no threads are grouped into blocks, i.e. that each block contains only 1 thread (with thread ID 1). The parallelization happens only over the blocks.
if (package == PKG_CUDA) return :(CUDA.threadIdx($(args...)))
elseif (package == PKG_AMDGPU) return :(AMDGPU.workitemIdx($(args...)))
+ elseif (package == PKG_METAL) return :(Metal.thread_position_in_threadgroup_3d($(args...)))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.@threadIdx_cpu($(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -207,6 +211,7 @@ end
function sync_threads(caller::Module, args...; package::Symbol=get_package(caller)) #NOTE: the CPU implementation follows the model that no threads are grouped into blocks, i.e. that each block contains only 1 thread (with thread ID 1). The parallelization happens only over the blocks. Synchronization within a block is therefore not needed (as it contains only one thread).
if (package == PKG_CUDA) return :(CUDA.sync_threads($(args...)))
elseif (package == PKG_AMDGPU) return :(AMDGPU.sync_workgroup($(args...)))
+ elseif (package == PKG_METAL) return :(Metal.threadgroup_barrier($(args...); flag=Metal.MemoryFlagThreadGroup))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.@sync_threads_cpu($(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -218,6 +223,7 @@ end
function sharedMem(caller::Module, args...; package::Symbol=get_package(caller))
if (package == PKG_CUDA) return :(CUDA.@cuDynamicSharedMem($(args...)))
elseif (package == PKG_AMDGPU) return :(ParallelStencil.ParallelKernel.@sharedMem_amdgpu($(args...)))
+ elseif (package == PKG_METAL) return :(ParallelStencil.ParallelKernel.@sharedMem_metal($(args...)))
elseif iscpu(package) return :(ParallelStencil.ParallelKernel.@sharedMem_cpu($(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -227,12 +233,16 @@ macro sharedMem_amdgpu(T, dims) esc(:(AMDGPU.@ROCDynamicLocalArray($T, $dims, fa
macro sharedMem_amdgpu(T, dims, offset) esc(:(ParallelStencil.ParallelKernel.@sharedMem_amdgpu($T, $dims))) end
+macro sharedMem_metal(T, dims) :(Metal.MtlThreadGroupArray($T, $dims)); end
+
+macro sharedMem_metal(T, dims, offset) esc(:(ParallelStencil.ParallelKernel.@sharedMem_metal($T, $dims))) end
## FUNCTIONS FOR PRINTING
function pk_show(caller::Module, args...; package::Symbol=get_package(caller))
if (package == PKG_CUDA) return :(CUDA.@cushow($(args...)))
elseif (package == PKG_AMDGPU) @KeywordArgumentError("this functionality is not yet supported in AMDGPU.jl.")
+ elseif (package == PKG_METAL) @KeywordArgumentError("this functionality is not yet supported in Metal.jl.")
elseif iscpu(package) return :(Base.@show($(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
@@ -241,6 +251,7 @@ end
function pk_println(caller::Module, args...; package::Symbol=get_package(caller))
if (package == PKG_CUDA) return :(CUDA.@cuprintln($(args...)))
elseif (package == PKG_AMDGPU) return :(AMDGPU.@rocprintln($(args...)))
+ elseif (package == PKG_METAL) @KeywordArgumentError("this functionality is not yet supported in Metal.jl.")
elseif iscpu(package) return :(Base.println($(args...)))
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
diff --git a/src/ParallelKernel/parallel.jl b/src/ParallelKernel/parallel.jl
index 44dfd967..8fb54f5b 100644
--- a/src/ParallelKernel/parallel.jl
+++ b/src/ParallelKernel/parallel.jl
@@ -15,8 +15,8 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re
- `kernelcall`: a call to a kernel that is declared parallel.
!!! note "Advanced optional arguments"
- `ranges::Tuple{UnitRange{},UnitRange{},UnitRange{}} | Tuple{UnitRange{},UnitRange{}} | Tuple{UnitRange{}} | UnitRange{}`: the ranges of indices in each dimension for which computations must be performed.
- - `nblocks::Tuple{Integer,Integer,Integer}`: the number of blocks to be used if the package CUDA or AMDGPU was selected with [`@init_parallel_kernel`](@ref).
- - `nthreads::Tuple{Integer,Integer,Integer}`: the number of threads to be used if the package CUDA or AMDGPU was selected with [`@init_parallel_kernel`](@ref).
+ - `nblocks::Tuple{Integer,Integer,Integer}`: the number of blocks to be used if the package CUDA, AMDGPU or Metal was selected with [`@init_parallel_kernel`](@ref).
+ - `nthreads::Tuple{Integer,Integer,Integer}`: the number of threads to be used if the package CUDA, AMDGPU or Metal was selected with [`@init_parallel_kernel`](@ref).
# Keyword arguments
!!! note "Advanced"
@@ -24,7 +24,7 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re
- `ad_mode=Enzyme.Reverse`: the automatic differentiation mode (see the documentation of Enzyme.jl for more information).
- `ad_annotations=()`: Enzyme variable annotations for automatic differentiation in the format `(=, =, ...)`, where `` can be a single variable or a tuple of variables (e.g., `ad_annotations=(Duplicated=B, Active=(a,b))`). Currently supported annotations are: $(keys(AD_SUPPORTED_ANNOTATIONS)).
- `configcall=kernelcall`: a call to a kernel that is declared parallel, which is used for determining the kernel launch parameters. This keyword is useful, e.g., for generic automatic differentiation using the low-level submodule [`AD`](@ref).
- - `backendkwargs...`: keyword arguments to be passed further to CUDA or AMDGPU (ignored for Threads or Polyester).
+ - `backendkwargs...`: keyword arguments to be passed further to CUDA, AMDGPU or Metal (ignored for Threads or Polyester).
!!! note "Performance note"
Kernel launch parameters are automatically defined with heuristics, where not defined with optional kernel arguments. For CUDA and AMDGPU, `nthreads` is typically set to (32,8,1) and `nblocks` accordingly to ensure that enough threads are launched.
@@ -90,18 +90,22 @@ macro synchronize(args...) check_initialized(__module__); esc(synchronize(__modu
macro parallel_cuda(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel(__module__, args...; package=PKG_CUDA)); end
macro parallel_amdgpu(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel(__module__, args...; package=PKG_AMDGPU)); end
+macro parallel_metal(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel(__module__, args...; package=PKG_METAL)); end
macro parallel_threads(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel(__module__, args...; package=PKG_THREADS)); end
macro parallel_polyester(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel(__module__, args...; package=PKG_POLYESTER)); end
macro parallel_indices_cuda(args...) check_initialized(__module__); checkargs_parallel_indices(args...); esc(parallel_indices(__module__, args...; package=PKG_CUDA)); end
macro parallel_indices_amdgpu(args...) check_initialized(__module__); checkargs_parallel_indices(args...); esc(parallel_indices(__module__, args...; package=PKG_AMDGPU)); end
+macro parallel_indices_metal(args...) check_initialized(__module__); checkargs_parallel_indices(args...); esc(parallel_indices(__module__, args...; package=PKG_METAL)); end
macro parallel_indices_threads(args...) check_initialized(__module__); checkargs_parallel_indices(args...); esc(parallel_indices(__module__, args...; package=PKG_THREADS)); end
macro parallel_indices_polyester(args...) check_initialized(__module__); checkargs_parallel_indices(args...); esc(parallel_indices(__module__, args...; package=PKG_POLYESTER)); end
macro parallel_async_cuda(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel_async(__module__, args...; package=PKG_CUDA)); end
macro parallel_async_amdgpu(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel_async(__module__, args...; package=PKG_AMDGPU)); end
+macro parallel_async_metal(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel_async(__module__, args...; package=PKG_METAL)); end
macro parallel_async_threads(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel_async(__module__, args...; package=PKG_THREADS)); end
macro parallel_async_polyester(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel_async(__module__, args...; package=PKG_POLYESTER)); end
macro synchronize_cuda(args...) check_initialized(__module__); esc(synchronize(__module__, args...; package=PKG_CUDA)); end
macro synchronize_amdgpu(args...) check_initialized(__module__); esc(synchronize(__module__, args...; package=PKG_AMDGPU)); end
+macro synchronize_metal(args...) check_initialized(__module__); esc(synchronize(__module__, args...; package=PKG_METAL)); end
macro synchronize_threads(args...) check_initialized(__module__); esc(synchronize(__module__, args...; package=PKG_THREADS)); end
macro synchronize_polyester(args...) check_initialized(__module__); esc(synchronize(__module__, args...; package=PKG_POLYESTER)); end
@@ -158,6 +162,7 @@ end
function synchronize(caller::Module, args::Union{Symbol,Expr}...; package::Symbol=get_package(caller))
if (package == PKG_CUDA) synchronize_cuda(args...)
elseif (package == PKG_AMDGPU) synchronize_amdgpu(args...)
+ elseif (package == PKG_METAL) synchronize_metal(args...)
elseif (package == PKG_THREADS) synchronize_threads(args...)
elseif (package == PKG_POLYESTER) synchronize_polyester(args...)
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
@@ -236,6 +241,7 @@ function parallel_call_gpu(ranges::Union{Symbol,Expr}, nblocks::Union{Symbol,Exp
ranges = :(ParallelStencil.ParallelKernel.promote_ranges($ranges))
if (package == PKG_CUDA) int_type = INT_CUDA
elseif (package == PKG_AMDGPU) int_type = INT_AMDGPU
+ elseif (package == PKG_METAL) int_type = INT_METAL
end
push!(kernelcall.args, ranges) #TODO: to enable indexing with other then Int64 something like the following but probably better in a function will also be necessary: push!(kernelcall.args, :(convert(Tuple{UnitRange{$int_type},UnitRange{$int_type},UnitRange{$int_type}}, $ranges)))
push!(kernelcall.args, :($int_type(length($ranges[1]))))
@@ -304,6 +310,7 @@ end
synchronize_cuda(args::Union{Symbol,Expr}...) = :(CUDA.synchronize($(args...); blocking=true))
synchronize_amdgpu(args::Union{Symbol,Expr}...) = :(AMDGPU.synchronize($(args...); blocking=true))
+synchronize_metal(args::Union{Symbol,Expr}...) = :(Metal.synchronize($(args...)))
synchronize_threads(args::Union{Symbol,Expr}...) = :(begin end)
synchronize_polyester(args::Union{Symbol,Expr}...) = :(begin end)
@@ -559,17 +566,22 @@ function create_gpu_call(package::Symbol, nblocks::Union{Symbol,Expr}, nthreads:
if !isnothing(shmem)
if (package == PKG_CUDA) shmem_expr = :(shmem = $shmem)
elseif (package == PKG_AMDGPU) shmem_expr = :(shmem = $shmem)
+ elseif (package == PKG_METAL) shmem_expr = nothing # No need to pass shared memory to Metal kernels.
else @ModuleInternalError("unsupported GPU package (obtained: $package).")
end
- backend_kwargs_expr = (backend_kwargs_expr..., shmem_expr)
+ if package != PKG_METAL
+ backend_kwargs_expr = (backend_kwargs_expr..., shmem_expr)
+ end
end
if (package == PKG_CUDA) return :( CUDA.@cuda blocks=$nblocks threads=$nthreads stream=$stream $(backend_kwargs_expr...) $kernelcall; $synccall )
elseif (package == PKG_AMDGPU) return :( AMDGPU.@roc gridsize=$nblocks groupsize=$nthreads stream=$stream $(backend_kwargs_expr...) $kernelcall; $synccall )
+ elseif (package == PKG_METAL) return :( Metal.@metal groups=$nblocks threads=$nthreads queue=$stream $(backend_kwargs_expr...) $kernelcall; $synccall )
else @ModuleInternalError("unsupported GPU package (obtained: $package).")
end
else
if (package == PKG_CUDA) return :( CUDA.@cuda launch=false $(backend_kwargs_expr...) $kernelcall) # NOTE: runtime arguments must be omitted when the kernel is not launched (backend_kwargs_expr must not contain any around time argument)
elseif (package == PKG_AMDGPU) return :( AMDGPU.@roc launch=false $(backend_kwargs_expr...) $kernelcall) # NOTE: ...
+ elseif (package == PKG_METAL) return :( Metal.@metal launch=false $(backend_kwargs_expr...) $kernelcall) # NOTE: ...
else @ModuleInternalError("unsupported GPU package (obtained: $package).")
end
end
@@ -578,6 +590,7 @@ end
function create_synccall(package::Symbol, stream::Union{Symbol,Expr})
if (package == PKG_CUDA) synchronize_cuda(stream)
elseif (package == PKG_AMDGPU) synchronize_amdgpu(stream)
+ elseif (package == PKG_METAL) synchronize_metal(stream)
else @ModuleInternalError("unsupported GPU package (obtained: $package).")
end
end
@@ -585,6 +598,7 @@ end
function default_stream(package)
if (package == PKG_CUDA) return :(CUDA.stream()) # Use the default stream of the task.
elseif (package == PKG_AMDGPU) return :(AMDGPU.stream()) # Use the default stream of the task.
+ elseif (package == PKG_METAL) return :(Metal.global_queue(Metal.device())) # Use the default queue of the task.
else @ModuleInternalError("unsupported GPU package (obtained: $package).")
end
end
\ No newline at end of file
diff --git a/src/ParallelKernel/shared.jl b/src/ParallelKernel/shared.jl
index 1d00a8f1..4c77ba2f 100644
--- a/src/ParallelKernel/shared.jl
+++ b/src/ParallelKernel/shared.jl
@@ -11,12 +11,14 @@ gensym_world(tag::Expr, generator::Module) = gensym(string(tag, GENSYM_SEPARAT
const PKG_CUDA = :CUDA
const PKG_AMDGPU = :AMDGPU
+const PKG_METAL = :Metal
const PKG_THREADS = :Threads
const PKG_POLYESTER = :Polyester
const PKG_NONE = :PKG_NONE
-const SUPPORTED_PACKAGES = [PKG_THREADS, PKG_POLYESTER, PKG_CUDA, PKG_AMDGPU]
+const SUPPORTED_PACKAGES = [PKG_THREADS, PKG_POLYESTER, PKG_CUDA, PKG_AMDGPU, PKG_METAL]
const INT_CUDA = Int64 # NOTE: unsigned integers are not yet supported (proper negative offset and range is dealing missing)
const INT_AMDGPU = Int64 # NOTE: ...
+const INT_METAL = Int64 # NOTE: ...
const INT_POLYESTER = Int64 # NOTE: ...
const INT_THREADS = Int64 # NOTE: ...
const NTHREADS_X_MAX = 32
@@ -66,6 +68,7 @@ const ERRMSG_CHECK_LITERALTYPES = "the type given to 'literaltype' must be on
const CELLARRAY_BLOCKLENGTH = Dict(PKG_NONE => 0,
PKG_CUDA => 0,
PKG_AMDGPU => 0,
+ PKG_METAL => 0,
PKG_THREADS => 1,
PKG_POLYESTER => 1)
@@ -81,6 +84,7 @@ macro rangelengths() esc(:(($(RANGELENGTHS_VARNAMES...),))) end
function kernel_int_type(package::Symbol)
if (package == PKG_CUDA) int_type = INT_CUDA
elseif (package == PKG_AMDGPU) int_type = INT_AMDGPU
+ elseif (package == PKG_METAL) int_type = INT_METAL
elseif (package == PKG_THREADS) int_type = INT_THREADS
elseif (package == PKG_POLYESTER) int_type = INT_POLYESTER
end
@@ -486,7 +490,7 @@ end
## FUNCTIONS/MACROS FOR DIVERSE SYNTAX SUGAR
iscpu(package) = return (package in (PKG_THREADS, PKG_POLYESTER))
-isgpu(package) = return (package in (PKG_CUDA, PKG_AMDGPU))
+isgpu(package) = return (package in (PKG_CUDA, PKG_AMDGPU, PKG_METAL))
## TEMPORARY FUNCTION DEFINITIONS TO BE MERGED IN MACROTOOLS (https://github.com/FluxML/MacroTools.jl/pull/173)
diff --git a/src/ParallelStencil.jl b/src/ParallelStencil.jl
index 2734020d..a46433a6 100644
--- a/src/ParallelStencil.jl
+++ b/src/ParallelStencil.jl
@@ -44,7 +44,7 @@ https://github.com/omlins/ParallelStencil.jl
- [`Data`](@ref)
!! note "Activation of GPU support"
- The support for GPU (CUDA or AMDGPU) is provided with extensions and requires therefore an explicit installation of the corresponding packages (CUDA.jl or AMDGPU.jl). Note that it is not required to import explicitly the corresponding module (CUDA or AMDGPU); this is automatically done by [`@init_parallel_stencil`](@ref).
+ The support for GPU (CUDA, AMDGPU or Metal) is provided with extensions and requires therefore an explicit installation of the corresponding packages (CUDA.jl, AMDGPU.jl or Metal.jl). Note that it is not required to import explicitly the corresponding module (CUDA, AMDGPU or Metal); this is automatically done by [`@init_parallel_stencil`](@ref).
To see a description of a macro or module type `?` (including the `@`) or `?`, respectively.
"""
diff --git a/src/init_parallel_stencil.jl b/src/init_parallel_stencil.jl
index 0cd790ad..23b1962b 100644
--- a/src/init_parallel_stencil.jl
+++ b/src/init_parallel_stencil.jl
@@ -28,7 +28,7 @@
Initialize the package ParallelStencil, giving access to its main functionality. Creates a module `Data` in the module where `@init_parallel_stencil` is called from. The module `Data` contains the types as `Data.Number`, `Data.Array` and `Data.CellArray` (type `?Data` *after* calling `@init_parallel_stencil` to see the full description of the module).
# Arguments
-- `package::Module`: the package used for parallelization (CUDA or AMDGPU for GPU, or Threads or Polyester for CPU).
+- `package::Module`: the package used for parallelization (CUDA, AMDGPU or Metal for GPU, or Threads or Polyester for CPU).
- `numbertype::DataType`: the type of numbers used by @zeros, @ones, @rand and @fill and in all array types of module `Data` (e.g. Float32 or Float64). It is contained in `Data.Number` after @init_parallel_stencil. The `numbertype` can be omitted if the other arguments are given as keyword arguments (in that case, the `numbertype` will have to be given explicitly when using the types provided by the module `Data`).
- `ndims::Integer`: the number of dimensions used for the stencil computations in the kernels: 1, 2 or 3 (overwritable in each kernel definition).
- `inbounds::Bool=false`: whether to apply `@inbounds` to the kernels by default (overwritable in each kernel definition).
diff --git a/src/kernel_language.jl b/src/kernel_language.jl
index 92d59e7a..6c7e4dd2 100644
--- a/src/kernel_language.jl
+++ b/src/kernel_language.jl
@@ -71,6 +71,7 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul
if (package ∉ SUPPORTED_PACKAGES) @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") end
if (package == PKG_CUDA) int_type = INT_CUDA
elseif (package == PKG_AMDGPU) int_type = INT_AMDGPU
+ elseif (package == PKG_METAL) int_type = INT_METAL
elseif (package == PKG_THREADS) int_type = INT_THREADS
end
body = eval_offsets(caller, body, indices, int_type)
diff --git a/src/parallel.jl b/src/parallel.jl
index d29baa11..b3bd7f71 100644
--- a/src/parallel.jl
+++ b/src/parallel.jl
@@ -34,8 +34,8 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re
- `kernelcall`: a call to a kernel that is declared parallel.
!!! note "Advanced optional arguments"
- `ranges::Tuple{UnitRange{},UnitRange{},UnitRange{}} | Tuple{UnitRange{},UnitRange{}} | Tuple{UnitRange{}} | UnitRange{}`: the ranges of indices in each dimension for which computations must be performed.
- - `nblocks::Tuple{Integer,Integer,Integer}`: the number of blocks to be used if the package CUDA or AMDGPU was selected with [`@init_parallel_kernel`](@ref).
- - `nthreads::Tuple{Integer,Integer,Integer}`: the number of threads to be used if the package CUDA or AMDGPU was selected with [`@init_parallel_kernel`](@ref).
+ - `nblocks::Tuple{Integer,Integer,Integer}`: the number of blocks to be used if the package CUDA, AMDGPU or Metal was selected with [`@init_parallel_kernel`](@ref).
+ - `nthreads::Tuple{Integer,Integer,Integer}`: the number of threads to be used if the package CUDA, AMDGPU or Metal was selected with [`@init_parallel_kernel`](@ref).
# Keyword arguments
- `memopt::Bool=false`: whether the kernel to be launched was generated with `memopt=true` (meaning the keyword was set in the kernel declaration).
@@ -44,7 +44,7 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re
- `ad_mode=Enzyme.Reverse`: the automatic differentiation mode (see the documentation of Enzyme.jl for more information).
- `ad_annotations=()`: Enzyme variable annotations for automatic differentiation in the format `(=, =, ...)`, where `` can be a single variable or a tuple of variables (e.g., `ad_annotations=(Duplicated=B, Active=(a,b))`). Currently supported annotations are: $(keys(AD_SUPPORTED_ANNOTATIONS)).
- `configcall=kernelcall`: a call to a kernel that is declared parallel, which is used for determining the kernel launch parameters. This keyword is useful, e.g., for generic automatic differentiation using the low-level submodule [`AD`](@ref).
- - `backendkwargs...`: keyword arguments to be passed further to CUDA or AMDGPU (ignored for Threads and Polyester).
+ - `backendkwargs...`: keyword arguments to be passed further to CUDA, AMDGPU or Metal (ignored for Threads and Polyester).
!!! note "Performance note"
Kernel launch parameters are automatically defined with heuristics, where not defined with optional kernel arguments. For CUDA and AMDGPU, `nthreads` is typically set to (32,8,1) and `nblocks` accordingly to ensure that enough threads are launched.
@@ -86,14 +86,17 @@ macro parallel_async(args...) check_initialized(__module__); checkargs_parallel(
macro parallel_cuda(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel(__source__, __module__, args...; package=PKG_CUDA)); end
macro parallel_amdgpu(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel(__source__, __module__, args...; package=PKG_AMDGPU)); end
+macro parallel_metal(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel(__source__, __module__, args...; package=PKG_METAL)); end
macro parallel_threads(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel(__source__, __module__, args...; package=PKG_THREADS)); end
macro parallel_polyester(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel(__source__, __module__, args...; package=PKG_POLYESTER)); end
macro parallel_indices_cuda(args...) check_initialized(__module__); checkargs_parallel_indices(args...); esc(parallel_indices(__source__, __module__, args...; package=PKG_CUDA)); end
macro parallel_indices_amdgpu(args...) check_initialized(__module__); checkargs_parallel_indices(args...); esc(parallel_indices(__source__, __module__, args...; package=PKG_AMDGPU)); end
+macro parallel_indices_metal(args...) check_initialized(__module__); checkargs_parallel_indices(args...); esc(parallel_indices(__source__, __module__, args...; package=PKG_METAL)); end
macro parallel_indices_threads(args...) check_initialized(__module__); checkargs_parallel_indices(args...); esc(parallel_indices(__source__, __module__, args...; package=PKG_THREADS)); end
macro parallel_indices_polyester(args...) check_initialized(__module__); checkargs_parallel_indices(args...); esc(parallel_indices(__source__, __module__, args...; package=PKG_POLYESTER)); end
macro parallel_async_cuda(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel_async(__source__, __module__, args...; package=PKG_CUDA)); end
macro parallel_async_amdgpu(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel_async(__source__, __module__, args...; package=PKG_AMDGPU)); end
+macro parallel_async_metal(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel_async(__source__, __module__, args...; package=PKG_METAL)); end
macro parallel_async_threads(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel_async(__source__, __module__, args...; package=PKG_THREADS)); end
macro parallel_async_polyester(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel_async(__source__, __module__, args...; package=PKG_POLYESTER)); end
@@ -350,7 +353,7 @@ end
## FUNCTIONS TO DETERMINE OPTIMIZATION PARAMETERS
-determine_nthreads_max_memopt(package::Symbol) = (package == PKG_AMDGPU) ? NTHREADS_MAX_MEMOPT_AMDGPU : NTHREADS_MAX_MEMOPT_CUDA
+determine_nthreads_max_memopt(package::Symbol) = (package == PKG_AMDGPU) ? NTHREADS_MAX_MEMOPT_AMDGPU : ((package == PKG_CUDA) ? NTHREADS_MAX_MEMOPT_CUDA : NTHREADS_MAX_MEMOPT_METAL)
determine_loopdim(indices::Union{Symbol,Expr}) = isa(indices,Expr) && (length(indices.args)==3) ? 3 : LOOPDIM_NONE # TODO: currently only loopdim=3 is supported.
compute_loopsize() = LOOPSIZE
diff --git a/src/shared.jl b/src/shared.jl
index 9f47b7c0..5b647da1 100644
--- a/src/shared.jl
+++ b/src/shared.jl
@@ -1,6 +1,6 @@
import MacroTools: @capture, postwalk, splitdef, splitarg # NOTE: inexpr_walk used instead of MacroTools.inexpr
import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, insert_device_types, is_kernel, is_call, gensym_world, isgpu, iscpu, @isgpu, @iscpu, substitute, substitute_in_kernel, in_signature, inexpr_walk, adjust_signatures, handle_indices_and_literals, add_inbounds, cast, @ranges, @rangelengths, @return_value, @return_nothing
-import .ParallelKernel: PKG_CUDA, PKG_AMDGPU, PKG_THREADS, PKG_POLYESTER, PKG_NONE, NUMBERTYPE_NONE, SUPPORTED_NUMBERTYPES, SUPPORTED_PACKAGES, ERRMSG_UNSUPPORTED_PACKAGE, INT_CUDA, INT_AMDGPU, INT_POLYESTER, INT_THREADS, INDICES, PKNumber, RANGES_VARNAME, RANGES_TYPE, RANGELENGTH_XYZ_TYPE, RANGELENGTHS_VARNAMES, THREADIDS_VARNAMES, GENSYM_SEPARATOR, AD_SUPPORTED_ANNOTATIONS
+import .ParallelKernel: PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_THREADS, PKG_POLYESTER, PKG_NONE, NUMBERTYPE_NONE, SUPPORTED_NUMBERTYPES, SUPPORTED_PACKAGES, ERRMSG_UNSUPPORTED_PACKAGE, INT_CUDA, INT_AMDGPU, INT_METAL, INT_POLYESTER, INT_THREADS, INDICES, PKNumber, RANGES_VARNAME, RANGES_TYPE, RANGELENGTH_XYZ_TYPE, RANGELENGTHS_VARNAMES, THREADIDS_VARNAMES, GENSYM_SEPARATOR, AD_SUPPORTED_ANNOTATIONS
import .ParallelKernel: @require, @symbols, symbols, longnameof, @prettyexpand, @prettystring, prettystring, @gorgeousexpand, @gorgeousstring, gorgeousstring
@@ -25,6 +25,7 @@ const LOOPSIZE = 16
const LOOPDIM_NONE = 0
const NTHREADS_MAX_MEMOPT_CUDA = 128
const NTHREADS_MAX_MEMOPT_AMDGPU = 256
+const NTHREADS_MAX_MEMOPT_METAL = 256
const USE_SHMEMHALO_DEFAULT = true
const USE_SHMEMHALO_1D_DEFAULT = true
const USE_FULLRANGE_DEFAULT = (false, false, true)
diff --git a/test/ParallelKernel/test_allocators.jl b/test/ParallelKernel/test_allocators.jl
index 07a701de..c0350d81 100644
--- a/test/ParallelKernel/test_allocators.jl
+++ b/test/ParallelKernel/test_allocators.jl
@@ -2,7 +2,7 @@ using Test
using CellArrays, StaticArrays
import ParallelStencil
using ParallelStencil.ParallelKernel
-import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, @get_numbertype, NUMBERTYPE_NONE, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU
+import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, @get_numbertype, NUMBERTYPE_NONE, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_POLYESTER
import ParallelStencil.ParallelKernel: @require, @prettystring, @gorgeousstring
import ParallelStencil.ParallelKernel: checkargs_CellType, _CellType
using ParallelStencil.ParallelKernel.FieldAllocators
@@ -19,10 +19,20 @@ end
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
@define_ROCCellArray
end
+@static if PKG_METAL in TEST_PACKAGES
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+ @define_MtlCellArray
+end
+@static if PKG_POLYESTER in TEST_PACKAGES
+ import Polyester
+end
Base.retry_load_extensions() # Potentially needed to load the extensions after the packages have been filtered.
const DATA_INDEX = ParallelStencil.INT_THREADS # TODO: using Data.Index does not work in combination with @reset_parallel_kernel, because the macros from module Test alternate the order of evaluation, resulting in the Data module being replaced with an empty module before Data.Index is evaluated. If at some point the indexing varies depending on the used package, then something more sophisticated is needed here (e.g., wrapping the test for each package in a module and using then Data.Index everywhere).
-@static for package in TEST_PACKAGES eval(:(
+@static for package in TEST_PACKAGES
+
+eval(:(
@testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin
@testset "1. @CellType macro" begin
@require !@is_initialized()
@@ -131,6 +141,19 @@ const DATA_INDEX = ParallelStencil.INT_THREADS # TODO: using Data.Index does not
@test typeof(@fill(9, 2,3)) == typeof(AMDGPU.ROCArray(fill(convert(Float16, 9), 2,3)))
@test typeof(@fill(9, 2,3, eltype=Float64)) == typeof(AMDGPU.ROCArray(fill(convert(Float64, 9), 2,3)))
@test typeof(@fill(9, 2,3, eltype=DATA_INDEX)) == typeof(AMDGPU.ROCArray(fill(convert(DATA_INDEX, 9), 2,3)))
+ elseif $package == $PKG_METAL
+ @test typeof(@zeros(2,3)) == typeof(Metal.MtlArray(zeros(Float16,2,3)))
+ @test typeof(@zeros(2,3, eltype=Float32)) == typeof(Metal.MtlArray(zeros(Float32,2,3)))
+ @test typeof(@zeros(2,3, eltype=DATA_INDEX)) == typeof(Metal.MtlArray(zeros(DATA_INDEX,2,3)))
+ @test typeof(@ones(2,3)) == typeof(Metal.MtlArray(ones(Float16,2,3)))
+ @test typeof(@ones(2,3, eltype=Float32)) == typeof(Metal.MtlArray(ones(Float32,2,3)))
+ @test typeof(@ones(2,3, eltype=DATA_INDEX)) == typeof(Metal.MtlArray(ones(DATA_INDEX,2,3)))
+ @test typeof(@rand(2,3)) == typeof(Metal.MtlArray(rand(Float16,2,3)))
+ @test typeof(@rand(2,3, eltype=Float32)) == typeof(Metal.MtlArray(rand(Float32,2,3)))
+ @test typeof(@rand(2,3, eltype=DATA_INDEX)) == typeof(Metal.MtlArray(rand(DATA_INDEX,2,3)))
+ @test typeof(@fill(9, 2,3)) == typeof(Metal.MtlArray(fill(convert(Float16, 9), 2,3)))
+ @test typeof(@fill(9, 2,3, eltype=Float32)) == typeof(Metal.MtlArray(fill(convert(Float32, 9), 2,3)))
+ @test typeof(@fill(9, 2,3, eltype=DATA_INDEX)) == typeof(Metal.MtlArray(fill(convert(DATA_INDEX, 9), 2,3)))
else
@test typeof(@zeros(2,3)) == typeof(parentmodule($package).zeros(Float16,2,3))
@test typeof(@zeros(2,3, eltype=Float32)) == typeof(parentmodule($package).zeros(Float32,2,3))
@@ -182,6 +205,20 @@ const DATA_INDEX = ParallelStencil.INT_THREADS # TODO: using Data.Index does not
@test @trues(2,3, celldims=(3,4)) == CellArrays.fill!(ROCCellArray{T_Bool}(undef,2,3), trues((3,4)))
@test @zeros(2,3, celldims=(3,4), eltype=DATA_INDEX) == CellArrays.fill!(ROCCellArray{T_Index}(undef,2,3), T_Index(zeros((3,4))))
AMDGPU.allowscalar(false) #TODO: check how to do
+ elseif $package == $PKG_METAL
+ Metal.allowscalar(true)
+ @test @zeros(2,3, celldims=(3,4)) == CellArrays.fill!(MtlCellArray{T_Float16}(undef,2,3), T_Float16(zeros((3,4))))
+ @test @zeros(2,3, celldims=(3,4), eltype=Float32) == CellArrays.fill!(MtlCellArray{T_Float32}(undef,2,3), T_Float32(zeros((3,4))))
+ @test @ones(2,3, celldims=(3,4)) == CellArrays.fill!(MtlCellArray{T_Float16}(undef,2,3), T_Float16(ones((3,4))))
+ @test @ones(2,3, celldims=(3,4), eltype=Float32) == CellArrays.fill!(MtlCellArray{T_Float32}(undef,2,3), T_Float32(ones((3,4))))
+ @test typeof(@rand(2,3, celldims=(3,4))) == typeof(MtlCellArray{T_Float16,0}(undef,2,3))
+ @test typeof(@rand(2,3, celldims=(3,4), eltype=Float32)) == typeof(MtlCellArray{T_Float32,0}(undef,2,3))
+ @test typeof(@fill(9, 2,3, celldims=(3,4))) == typeof(MtlCellArray{T_Float16,0}(undef,2,3))
+ @test typeof(@fill(9, 2,3, celldims=(3,4), eltype=Float32)) == typeof(MtlCellArray{T_Float32,0}(undef,2,3))
+ @test @falses(2,3, celldims=(3,4)) == CellArrays.fill!(MtlCellArray{T_Bool}(undef,2,3), falses((3,4)))
+ @test @trues(2,3, celldims=(3,4)) == CellArrays.fill!(MtlCellArray{T_Bool}(undef,2,3), trues((3,4)))
+ @test @zeros(2,3, celldims=(3,4), eltype=DATA_INDEX) == CellArrays.fill!(MtlCellArray{T_Index}(undef,2,3), T_Index(zeros((3,4))))
+ Metal.allowscalar(false)
else
@test @zeros(2,3, celldims=(3,4)) == CellArrays.fill!(CPUCellArray{T_Float16}(undef,2,3), T_Float16(zeros((3,4))))
@test @zeros(2,3, celldims=(3,4), eltype=Float32) == CellArrays.fill!(CPUCellArray{T_Float32}(undef,2,3), T_Float32(zeros((3,4))))
@@ -221,6 +258,18 @@ const DATA_INDEX = ParallelStencil.INT_THREADS # TODO: using Data.Index does not
@test typeof(@fill(9, 2,3, celltype=SymmetricTensor2D)) == typeof(ROCCellArray{SymmetricTensor2D,0}(undef,2,3))
@test @zeros(2,3, celltype=SymmetricTensor2D_Index) == CellArrays.fill!(ROCCellArray{SymmetricTensor2D_Index}(undef,2,3), SymmetricTensor2D_Index(zeros(3)))
AMDGPU.allowscalar(false)
+ elseif $package == $PKG_METAL
+ Metal.allowscalar(true)
+ @test @zeros(2,3, celltype=SymmetricTensor2D) == CellArrays.fill!(MtlCellArray{SymmetricTensor2D}(undef,2,3), SymmetricTensor2D(zeros(3)))
+ @test @zeros(2,3, celltype=SymmetricTensor3D) == CellArrays.fill!(MtlCellArray{SymmetricTensor3D}(undef,2,3), SymmetricTensor3D(zeros(6)))
+ @test @zeros(2,3, celltype=Tensor2D) == CellArrays.fill!(MtlCellArray{Tensor2D}(undef,2,3), Tensor2D(zeros((2,2,2,2))))
+ @test @zeros(2,3, celltype=SymmetricTensor2D_T{Float32}) == CellArrays.fill!(MtlCellArray{SymmetricTensor2D_T{Float32}}(undef,2,3), SymmetricTensor2D_T{Float64}(zeros(3)))
+ @test @zeros(2,3, celltype=SymmetricTensor2D_Float32) == CellArrays.fill!(MtlCellArray{SymmetricTensor2D_Float32}(undef,2,3), SymmetricTensor2D_Float32(zeros(3)))
+ @test @ones(2,3, celltype=SymmetricTensor2D) == CellArrays.fill!(MtlCellArray{SymmetricTensor2D}(undef,2,3), SymmetricTensor2D(ones(3)))
+ @test typeof(@rand(2,3, celltype=SymmetricTensor2D)) == typeof(MtlCellArray{SymmetricTensor2D,0}(undef,2,3))
+ @test typeof(@fill(9, 2,3, celltype=SymmetricTensor2D)) == typeof(MtlCellArray{SymmetricTensor2D,0}(undef,2,3))
+ @test @zeros(2,3, celltype=SymmetricTensor2D_Index) == CellArrays.fill!(MtlCellArray{SymmetricTensor2D_Index}(undef,2,3), SymmetricTensor2D_Index(zeros(3)))
+ Metal.allowscalar(false)
else
@test @zeros(2,3, celltype=SymmetricTensor2D) == CellArrays.fill!(CPUCellArray{SymmetricTensor2D}(undef,2,3), SymmetricTensor2D(zeros(3)))
@test @zeros(2,3, celltype=SymmetricTensor3D) == CellArrays.fill!(CPUCellArray{SymmetricTensor3D}(undef,2,3), SymmetricTensor3D(zeros(6)))
@@ -267,6 +316,12 @@ const DATA_INDEX = ParallelStencil.INT_THREADS # TODO: using Data.Index does not
@test typeof(@rand(2,3, eltype=Float64)) == typeof(AMDGPU.ROCArray(rand(Float64,2,3)))
@test typeof(@fill(9, 2,3, eltype=Float64)) == typeof(AMDGPU.ROCArray(fill(convert(Float64, 9), 2,3)))
@test typeof(@zeros(2,3, eltype=DATA_INDEX)) == typeof(AMDGPU.ROCArray(zeros(DATA_INDEX,2,3)))
+ elseif $package == $PKG_METAL
+ @test typeof(@zeros(2,3, eltype=Float32)) == typeof(Metal.MtlArray(zeros(Float32,2,3)))
+ @test typeof(@ones(2,3, eltype=Float32)) == typeof(Metal.MtlArray(ones(Float32,2,3)))
+ @test typeof(@rand(2,3, eltype=Float32)) == typeof(Metal.MtlArray(rand(Float32,2,3)))
+ @test typeof(@fill(9, 2,3, eltype=Float32)) == typeof(Metal.MtlArray(fill(convert(Float32, 9), 2,3)))
+ @test typeof(@zeros(2,3, eltype=DATA_INDEX)) == typeof(Metal.MtlArray(zeros(DATA_INDEX,2,3)))
else
@test typeof(@zeros(2,3, eltype=Float32)) == typeof(zeros(Float32,2,3))
@test typeof(@ones(2,3, eltype=Float32)) == typeof(ones(Float32,2,3))
@@ -300,6 +355,15 @@ const DATA_INDEX = ParallelStencil.INT_THREADS # TODO: using Data.Index does not
@test @falses(2,3, celldims=(3,4)) == CellArrays.fill!(ROCCellArray{T_Bool}(undef,2,3), falses((3,4)))
@test @trues(2,3, celldims=(3,4)) == CellArrays.fill!(ROCCellArray{T_Bool}(undef,2,3), trues((3,4)))
AMDGPU.allowscalar(false)
+ elseif $package == $PKG_METAL
+ Metal.allowscalar(true)
+ @test @zeros(2,3, celldims=(3,4), eltype=Float32) == CellArrays.fill!(MtlCellArray{T_Float32}(undef,2,3), T_Float32(zeros((3,4))))
+ @test @ones(2,3, celldims=(3,4), eltype=Float32) == CellArrays.fill!(MtlCellArray{T_Float32}(undef,2,3), T_Float32(ones((3,4))))
+ @test typeof(@rand(2,3, celldims=(3,4), eltype=Float32)) == typeof(MtlCellArray{T_Float32,0}(undef,2,3))
+ @test typeof(@fill(9, 2,3, celldims=(3,4), eltype=Float32)) == typeof(MtlCellArray{T_Float32,0}(undef,2,3))
+ @test @falses(2,3, celldims=(3,4)) == CellArrays.fill!(MtlCellArray{T_Bool}(undef,2,3), falses((3,4)))
+ @test @trues(2,3, celldims=(3,4)) == CellArrays.fill!(MtlCellArray{T_Bool}(undef,2,3), trues((3,4)))
+ Metal.allowscalar(false)
else
@test @zeros(2,3, celldims=(3,4), eltype=Float32) == CellArrays.fill!(CPUCellArray{T_Float32}(undef,2,3), T_Float32(zeros((3,4))))
@test @ones(2,3, celldims=(3,4), eltype=Float32) == CellArrays.fill!(CPUCellArray{T_Float32}(undef,2,3), T_Float32(ones((3,4))))
@@ -332,6 +396,17 @@ const DATA_INDEX = ParallelStencil.INT_THREADS # TODO: using Data.Index does not
@test typeof(@rand(2,3, celltype=SymmetricTensor2D)) == typeof(ROCCellArray{SymmetricTensor2D,0}(undef,2,3))
@test typeof(@fill(9, 2,3, celltype=SymmetricTensor2D)) == typeof(ROCCellArray{SymmetricTensor2D,0}(undef,2,3))
AMDGPU.allowscalar(false)
+ elseif $package == $PKG_METAL
+ Metal.allowscalar(true)
+ @test @zeros(2,3, celltype=SymmetricTensor2D) == CellArrays.fill!(MtlCellArray{SymmetricTensor2D}(undef,2,3), SymmetricTensor2D(zeros(3)))
+ @test @zeros(2,3, celltype=SymmetricTensor3D) == CellArrays.fill!(MtlCellArray{SymmetricTensor3D}(undef,2,3), SymmetricTensor3D(zeros(6)))
+ @test @zeros(2,3, celltype=Tensor2D) == CellArrays.fill!(MtlCellArray{Tensor2D}(undef,2,3), Tensor2D(zeros((2,2,2,2))))
+ @test @zeros(2,3, celltype=SymmetricTensor2D_T{Float32}) == CellArrays.fill!(MtlCellArray{SymmetricTensor2D_T{Float32}}(undef,2,3), SymmetricTensor2D_T{Float32}(zeros(3)))
+ @test @zeros(2,3, celltype=SymmetricTensor2D_Float32) == CellArrays.fill!(MtlCellArray{SymmetricTensor2D_Float32}(undef,2,3), SymmetricTensor2D_Float32(zeros(3)))
+ @test @ones(2,3, celltype=SymmetricTensor2D) == CellArrays.fill!(MtlCellArray{SymmetricTensor2D}(undef,2,3), SymmetricTensor2D(ones(3)))
+ @test typeof(@rand(2,3, celltype=SymmetricTensor2D)) == typeof(MtlCellArray{SymmetricTensor2D,0}(undef,2,3))
+ @test typeof(@fill(9, 2,3, celltype=SymmetricTensor2D)) == typeof(MtlCellArray{SymmetricTensor2D,0}(undef,2,3))
+ Metal.allowscalar(false)
else
@test @zeros(2,3, celltype=SymmetricTensor2D) == CellArrays.fill!(CPUCellArray{SymmetricTensor2D}(undef,2,3), SymmetricTensor2D(zeros(3)))
@test @zeros(2,3, celltype=SymmetricTensor3D) == CellArrays.fill!(CPUCellArray{SymmetricTensor3D}(undef,2,3), SymmetricTensor3D(zeros(6)))
@@ -370,6 +445,15 @@ const DATA_INDEX = ParallelStencil.INT_THREADS # TODO: using Data.Index does not
@test typeof( @falses(2,3, celldims=(3,4))) == typeof(ROCCellArray{T_Bool, 0}(undef,2,3))
@test typeof( @trues(2,3, celldims=(3,4))) == typeof(ROCCellArray{T_Bool, 0}(undef,2,3))
AMDGPU.allowscalar(false)
+ elseif $package == $PKG_METAL
+ Metal.allowscalar(true)
+ @test typeof( @zeros(2,3, celldims=(3,4))) == typeof(MtlCellArray{T_Float16,0}(undef,2,3))
+ @test typeof( @ones(2,3, celldims=(3,4))) == typeof(MtlCellArray{T_Float16,0}(undef,2,3))
+ @test typeof( @rand(2,3, celldims=(3,4))) == typeof(MtlCellArray{T_Float16,0}(undef,2,3))
+ @test typeof(@fill(9, 2,3, celldims=(3,4))) == typeof(MtlCellArray{T_Float16,0}(undef,2,3))
+ @test typeof( @falses(2,3, celldims=(3,4))) == typeof(MtlCellArray{T_Bool, 0}(undef,2,3))
+ @test typeof( @trues(2,3, celldims=(3,4))) == typeof(MtlCellArray{T_Bool, 0}(undef,2,3))
+ Metal.allowscalar(false)
else
@test typeof( @zeros(2,3, celldims=(3,4))) == typeof(CPUCellArray{T_Float16,1}(undef,2,3))
@test typeof( @ones(2,3, celldims=(3,4))) == typeof(CPUCellArray{T_Float16,1}(undef,2,3))
@@ -410,6 +494,21 @@ const DATA_INDEX = ParallelStencil.INT_THREADS # TODO: using Data.Index does not
@test typeof( @falses(2,3, celldims=(3,4), blocklength=3)) == typeof(ROCCellArray{T_Bool, 3}(undef,2,3))
@test typeof( @trues(2,3, celldims=(3,4), blocklength=3)) == typeof(ROCCellArray{T_Bool, 3}(undef,2,3))
AMDGPU.allowscalar(false)
+ elseif $package == $PKG_METAL
+ Metal.allowscalar(true)
+ @test typeof( @zeros(2,3, celldims=(3,4), blocklength=1)) == typeof(MtlCellArray{T_Float16,1}(undef,2,3))
+ @test typeof( @ones(2,3, celldims=(3,4), blocklength=1)) == typeof(MtlCellArray{T_Float16,1}(undef,2,3))
+ @test typeof( @rand(2,3, celldims=(3,4), blocklength=1)) == typeof(MtlCellArray{T_Float16,1}(undef,2,3))
+ @test typeof(@fill(9, 2,3, celldims=(3,4), blocklength=1)) == typeof(MtlCellArray{T_Float16,1}(undef,2,3))
+ @test typeof( @falses(2,3, celldims=(3,4), blocklength=1)) == typeof(MtlCellArray{T_Bool, 1}(undef,2,3))
+ @test typeof( @trues(2,3, celldims=(3,4), blocklength=1)) == typeof(MtlCellArray{T_Bool, 1}(undef,2,3))
+ @test typeof( @zeros(2,3, celldims=(3,4), blocklength=3)) == typeof(MtlCellArray{T_Float16,3}(undef,2,3))
+ @test typeof( @ones(2,3, celldims=(3,4), blocklength=3)) == typeof(MtlCellArray{T_Float16,3}(undef,2,3))
+ @test typeof( @rand(2,3, celldims=(3,4), blocklength=3)) == typeof(MtlCellArray{T_Float16,3}(undef,2,3))
+ @test typeof(@fill(9, 2,3, celldims=(3,4), blocklength=3)) == typeof(MtlCellArray{T_Float16,3}(undef,2,3))
+ @test typeof( @falses(2,3, celldims=(3,4), blocklength=3)) == typeof(MtlCellArray{T_Bool, 3}(undef,2,3))
+ @test typeof( @trues(2,3, celldims=(3,4), blocklength=3)) == typeof(MtlCellArray{T_Bool, 3}(undef,2,3))
+ Metal.allowscalar(false)
else
@test typeof( @zeros(2,3, celldims=(3,4), blocklength=0)) == typeof(CPUCellArray{T_Float16,0}(undef,2,3))
@test typeof( @ones(2,3, celldims=(3,4), blocklength=0)) == typeof(CPUCellArray{T_Float16,0}(undef,2,3))
@@ -449,6 +548,14 @@ const DATA_INDEX = ParallelStencil.INT_THREADS # TODO: using Data.Index does not
@test typeof(@fill(solid, 2,3, celldims=(3,4), eltype=Phase)) == typeof(ROCCellArray{T_Phase,0}(undef,2,3))
@test typeof(@fill(@rand(3,4,eltype=Phase), 2,3, celldims=(3,4), eltype=Phase)) == typeof(ROCCellArray{T_Phase,0}(undef,2,3))
AMDGPU.allowscalar(false)
+ elseif $package == $PKG_METAL
+ Metal.allowscalar(true)
+ @test typeof(@rand(2,3, eltype=Phase)) == typeof(Metal.MtlArray(rand(Phase, 2,3)))
+ # @test typeof(@rand(2,3, celldims=(3,4), eltype=Phase)) == typeof(MtlCellArray{T_Phase,0}(undef,2,3)) # TODO fails because of bug in Metal.jl RNG implementation
+ @test typeof(@fill(solid, 2,3, eltype=Phase)) == typeof(Metal.MtlArray(rand(Phase, 2,3)))
+ @test typeof(@fill(solid, 2,3, celldims=(3,4), eltype=Phase)) == typeof(MtlCellArray{T_Phase,0}(undef,2,3))
+ @test typeof(@fill(@rand(3,4,eltype=Phase), 2,3, celldims=(3,4), eltype=Phase)) == typeof(MtlCellArray{T_Phase,0}(undef,2,3))
+ Metal.allowscalar(false)
else
@test typeof(@rand(2,3, eltype=Phase)) == typeof(rand(Phase, 2,3))
@test typeof(@rand(2,3, celldims=(3,4), eltype=Phase)) == typeof(CPUCellArray{T_Phase,1}(undef,2,3))
@@ -651,30 +758,3 @@ const DATA_INDEX = ParallelStencil.INT_THREADS # TODO: using Data.Index does not
end;
end;
)) end == nothing || true;
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/test/ParallelKernel/test_hide_communication.jl b/test/ParallelKernel/test_hide_communication.jl
index 4cbc2e1c..48171b19 100644
--- a/test/ParallelKernel/test_hide_communication.jl
+++ b/test/ParallelKernel/test_hide_communication.jl
@@ -1,7 +1,7 @@
using Test
import ParallelStencil
using ParallelStencil.ParallelKernel
-import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU
+import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_POLYESTER
import ParallelStencil.ParallelKernel: @require, @prettyexpand, @gorgeousexpand, gorgeousstring, @isgpu
import ParallelStencil.ParallelKernel: checkargs_hide_communication, hide_communication_gpu
using ParallelStencil.ParallelKernel.Exceptions
@@ -14,13 +14,29 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ @static if Sys.isapple()
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+ else
+ TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES)
+ end
+end
+@static if PKG_POLYESTER in TEST_PACKAGES
+ import Polyester
+end
Base.retry_load_extensions() # Potentially needed to load the extensions after the packages have been filtered.
-@static for package in TEST_PACKAGES eval(:(
- @testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin
+const TEST_PRECISIONS = [Float32, Float64]
+@static for package in TEST_PACKAGES
+for precision in TEST_PRECISIONS
+(package == PKG_METAL && precision == Float64) ? continue : nothing # Metal does not support Float64
+
+eval(:(
+ @testset "$(basename(@__FILE__)) (package: $(nameof($package))) (precision: $(nameof($precision)))" begin
@testset "1. hide_communication macro" begin
@require !@is_initialized()
- @init_parallel_kernel($package, Float64)
+ @init_parallel_kernel($package, $precision)
@require @is_initialized()
@testset "@hide_communication boundary_width block (macro expansion)" begin
@static if @isgpu($package)
@@ -164,7 +180,7 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
end;
@testset "2. Exceptions" begin
@require !@is_initialized()
- @init_parallel_kernel($package, Float64)
+ @init_parallel_kernel($package, $precision)
@require @is_initialized
@testset "arguments @hide_communication" begin
@test_throws ArgumentError checkargs_hide_communication(:boundary_width, :block) # Error: the last argument must be a code block.
@@ -204,4 +220,6 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
@reset_parallel_kernel()
end;
end;
-)) end == nothing || true;
+))
+
+end end == nothing || true;
diff --git a/test/ParallelKernel/test_init_parallel_kernel.jl b/test/ParallelKernel/test_init_parallel_kernel.jl
index d2599597..e26308bb 100644
--- a/test/ParallelKernel/test_init_parallel_kernel.jl
+++ b/test/ParallelKernel/test_init_parallel_kernel.jl
@@ -1,7 +1,7 @@
using Test
import ParallelStencil
using ParallelStencil.ParallelKernel
-import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, @get_package, @get_numbertype, @get_inbounds, NUMBERTYPE_NONE, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, SCALARTYPES, ARRAYTYPES, FIELDTYPES
+import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, @get_package, @get_numbertype, @get_inbounds, NUMBERTYPE_NONE, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_POLYESTER, SCALARTYPES, ARRAYTYPES, FIELDTYPES
import ParallelStencil.ParallelKernel: @require, @symbols
import ParallelStencil.ParallelKernel: extract_posargs_init, extract_kwargs_init, check_already_initialized, set_initialized, is_initialized, check_initialized
using ParallelStencil.ParallelKernel.Exceptions
@@ -14,6 +14,17 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ @static if Sys.isapple()
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+ else
+ TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES)
+ end
+end
+@static if PKG_POLYESTER in TEST_PACKAGES
+ import Polyester
+end
Base.retry_load_extensions() # Potentially needed to load the extensions after the packages have been filtered.
@static for package in TEST_PACKAGES eval(:(
diff --git a/test/ParallelKernel/test_kernel_language.jl b/test/ParallelKernel/test_kernel_language.jl
index 3b6da0dc..fe4ffd76 100644
--- a/test/ParallelKernel/test_kernel_language.jl
+++ b/test/ParallelKernel/test_kernel_language.jl
@@ -1,7 +1,7 @@
using Test
import ParallelStencil
using ParallelStencil.ParallelKernel
-import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_THREADS, PKG_POLYESTER
+import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_THREADS, PKG_POLYESTER
import ParallelStencil.ParallelKernel: @require, @prettystring, @iscpu
import ParallelStencil.ParallelKernel: checknoargs, checkargs_sharedMem, Dim3
using ParallelStencil.ParallelKernel.Exceptions
@@ -14,13 +14,25 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+end
+@static if PKG_POLYESTER in TEST_PACKAGES
+ import Polyester
+end
Base.retry_load_extensions() # Potentially needed to load the extensions after the packages have been filtered.
-@static for package in TEST_PACKAGES eval(:(
- @testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin
+const TEST_PRECISIONS = [Float32, Float64]
+@static for package in TEST_PACKAGES
+for precision in TEST_PRECISIONS
+(package == PKG_METAL && precision == Float64) ? continue : nothing # Metal does not support Float64
+
+eval(:(
+ @testset "$(basename(@__FILE__)) (package: $(nameof($package))) (precision: $(nameof($precision)))" begin
@testset "1. kernel language macros" begin
@require !@is_initialized()
- @init_parallel_kernel($package, Float64)
+ @init_parallel_kernel($package, $precision)
@require @is_initialized()
@testset "mapping to package" begin
if $package == $PKG_CUDA
@@ -29,7 +41,7 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
@test @prettystring(1, @blockDim()) == "CUDA.blockDim()"
@test @prettystring(1, @threadIdx()) == "CUDA.threadIdx()"
@test @prettystring(1, @sync_threads()) == "CUDA.sync_threads()"
- @test @prettystring(1, @sharedMem(Float32, (2,3))) == "CUDA.@cuDynamicSharedMem Float32 (2, 3)"
+ @test @prettystring(1, @sharedMem($precision, (2,3))) == "CUDA.@cuDynamicSharedMem $(nameof($precision)) (2, 3)"
# @test @prettystring(1, @pk_show()) == "CUDA.@cushow"
# @test @prettystring(1, @pk_println()) == "CUDA.@cuprintln"
elseif $package == $AMDGPU
@@ -38,16 +50,25 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
@test @prettystring(1, @blockDim()) == "AMDGPU.workgroupDim()"
@test @prettystring(1, @threadIdx()) == "AMDGPU.workitemIdx()"
@test @prettystring(1, @sync_threads()) == "AMDGPU.sync_workgroup()"
- # @test @prettystring(1, @sharedMem(Float32, (2,3))) == "" #TODO: not yet supported for AMDGPU
+ # @test @prettystring(1, @sharedMem($precision, (2,3))) == "" #TODO: not yet supported for AMDGPU
# @test @prettystring(1, @pk_show()) == "CUDA.@cushow" #TODO: not yet supported for AMDGPU
# @test @prettystring(1, @pk_println()) == "AMDGPU.@rocprintln"
+ elseif $package == $PKG_METAL
+ @test @prettystring(1, @gridDim()) == "Metal.threadgroups_per_grid_3d()"
+ @test @prettystring(1, @blockIdx()) == "Metal.threadgroup_position_in_grid_3d()"
+ @test @prettystring(1, @blockDim()) == "Metal.threads_per_threadgroup_3d()"
+ @test @prettystring(1, @threadIdx()) == "Metal.thread_position_in_threadgroup_3d()"
+ @test @prettystring(1, @sync_threads()) == "Metal.threadgroup_barrier(; flag = Metal.MemoryFlagThreadGroup)"
+ @test @prettystring(1, @sharedMem($precision, (2,3))) == "ParallelStencil.ParallelKernel.@sharedMem_metal $(nameof($precision)) (2, 3)"
+ # @test @prettystring(1, @pk_show()) == "Metal.@mtlshow" #TODO: not yet supported for Metal
+ # @test @prettystring(1, @pk_println()) == "Metal.@mtlprintln" #TODO: not yet supported for Metal
elseif @iscpu($package)
@test @prettystring(1, @gridDim()) == "ParallelStencil.ParallelKernel.@gridDim_cpu"
@test @prettystring(1, @blockIdx()) == "ParallelStencil.ParallelKernel.@blockIdx_cpu"
@test @prettystring(1, @blockDim()) == "ParallelStencil.ParallelKernel.@blockDim_cpu"
@test @prettystring(1, @threadIdx()) == "ParallelStencil.ParallelKernel.@threadIdx_cpu"
@test @prettystring(1, @sync_threads()) == "ParallelStencil.ParallelKernel.@sync_threads_cpu"
- @test @prettystring(1, @sharedMem(Float32, (2,3))) == "ParallelStencil.ParallelKernel.@sharedMem_cpu Float32 (2, 3)"
+ @test @prettystring(1, @sharedMem($precision, (2,3))) == "ParallelStencil.ParallelKernel.@sharedMem_cpu $(nameof($precision)) (2, 3)"
# @test @prettystring(1, @pk_show()) == "Base.@show"
# @test @prettystring(1, @pk_println()) == "Base.println()"
end;
@@ -117,7 +138,7 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
end;
@testset "shared memory (allocation)" begin
@static if @iscpu($package)
- @test typeof(@sharedMem(Float32,(2,3))) == typeof(ParallelStencil.ParallelKernel.MArray{Tuple{2,3}, Float32, length((2,3)), prod((2,3))}(undef))
+ @test typeof(@sharedMem($precision,(2,3))) == typeof(ParallelStencil.ParallelKernel.MArray{Tuple{2,3}, $precision, length((2,3)), prod((2,3))}(undef))
@test typeof(@sharedMem(Bool,(2,3,4))) == typeof(ParallelStencil.ParallelKernel.MArray{Tuple{2,3,4}, Bool, length((2,3,4)), prod((2,3,4))}(undef))
end;
end;
@@ -193,7 +214,7 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
@reset_parallel_kernel()
end;
@testset "2. Exceptions" begin
- @init_parallel_kernel($package, Float64)
+ @init_parallel_kernel($package, $precision)
@require @is_initialized
@testset "no arguments" begin
@test_throws ArgumentError checknoargs(:(something)); # Error: length(args) != 0
@@ -206,4 +227,6 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
@reset_parallel_kernel()
end;
end;
-)) end == nothing || true;
+))
+
+end end == nothing || true;
diff --git a/test/ParallelKernel/test_parallel.jl b/test/ParallelKernel/test_parallel.jl
index e69d64c8..a6585847 100644
--- a/test/ParallelKernel/test_parallel.jl
+++ b/test/ParallelKernel/test_parallel.jl
@@ -3,7 +3,7 @@ import ParallelStencil
using Enzyme
using ParallelStencil.ParallelKernel
import ParallelStencil.ParallelKernel.AD
-import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_THREADS, PKG_POLYESTER, INDICES, ARRAYTYPES, FIELDTYPES
+import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_THREADS, PKG_POLYESTER, INDICES, ARRAYTYPES, FIELDTYPES
import ParallelStencil.ParallelKernel: @require, @prettystring, @gorgeousstring, @isgpu, @iscpu
import ParallelStencil.ParallelKernel: checkargs_parallel, checkargs_parallel_indices, parallel_indices, maxsize
using ParallelStencil.ParallelKernel.Exceptions
@@ -16,6 +16,10 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+end
@static if PKG_POLYESTER in TEST_PACKAGES
import Polyester
end
@@ -24,11 +28,17 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
macro compute(A) esc(:($(INDICES[1]) + ($(INDICES[2])-1)*size($A,1))) end
macro compute_with_aliases(A) esc(:(ix + (iz -1)*size($A,1))) end
import Enzyme
-@static for package in TEST_PACKAGES eval(:(
- @testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin
+
+const TEST_PRECISIONS = [Float32, Float64]
+@static for package in TEST_PACKAGES
+for precision in TEST_PRECISIONS
+(package == PKG_METAL && precision == Float64) ? continue : nothing # Metal does not support Float64
+
+eval(:(
+ @testset "$(basename(@__FILE__)) (package: $(nameof($package))) (precision: $(nameof($precision)))" begin
@testset "1. parallel macros" begin
@require !@is_initialized()
- @init_parallel_kernel($package, Float64)
+ @init_parallel_kernel($package, $precision)
@require @is_initialized()
@testset "@parallel" begin
@static if $package == $PKG_CUDA
@@ -55,6 +65,18 @@ import Enzyme
@test occursin("AMDGPU.@roc gridsize = nblocks groupsize = nthreads stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
call = @prettystring(1, @parallel nblocks nthreads stream=mystream f(A))
@test occursin("AMDGPU.@roc gridsize = nblocks groupsize = nthreads stream = mystream f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[3])))", call)
+ elseif $package == $PKG_METAL
+ call = @prettystring(1, @parallel f(A))
+ @test occursin("Metal.@metal groups = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32) queue = Metal.global_queue(Metal.device()) f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call)
+ @test occursin("Metal.synchronize(Metal.global_queue(Metal.device()))", call)
+ call = @prettystring(1, @parallel ranges f(A))
+ @test occursin("Metal.@metal groups = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32) queue = Metal.global_queue(Metal.device()) f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
+ call = @prettystring(1, @parallel nblocks nthreads f(A))
+ @test occursin("Metal.@metal groups = nblocks threads = nthreads queue = Metal.global_queue(Metal.device()) f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[3])))", call)
+ call = @prettystring(1, @parallel ranges nblocks nthreads f(A))
+ @test occursin("Metal.@metal groups = nblocks threads = nthreads queue = Metal.global_queue(Metal.device()) f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
+ call = @prettystring(1, @parallel nblocks nthreads stream=mystream f(A))
+ @test occursin("Metal.@metal groups = nblocks threads = nthreads queue = mystream f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[3])))", call)
elseif @iscpu($package)
@test @prettystring(1, @parallel f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))"
@test @prettystring(1, @parallel ranges f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))"
@@ -70,7 +92,7 @@ import Enzyme
@testset "maxsize" begin
struct BitstypeStruct
x::Int
- y::Float64
+ y::Float32
end
@test maxsize([9 9; 9 9; 9 9]) == (3, 2, 1)
@test maxsize(8) == (1, 1, 1)
@@ -101,8 +123,8 @@ import Enzyme
B̄ = @ones(N)
A_ref = Array(A)
B_ref = Array(B)
- Ā_ref = ones(N)
- B̄_ref = ones(N)
+ Ā_ref = ones($precision, N)
+ B̄_ref = ones($precision, N)
@parallel_indices (ix) function f!(A, B, a)
A[ix] += a * B[ix] * 100.65
return
@@ -489,19 +511,21 @@ import Enzyme
@reset_parallel_kernel()
end;
@testset "2. parallel macros (literal conversion)" begin
- @testset "@parallel_indices (Float64)" begin
- @require !@is_initialized()
- @init_parallel_kernel($package, Float64)
- @require @is_initialized()
- expansion = @gorgeousstring(@parallel_indices (ix) f!(A) = (A[ix] = A[ix] + 1.0f0; return))
- @test occursin("A[ix] = A[ix] + 1.0\n", expansion)
- @reset_parallel_kernel()
- end;
+ if $package != $PKG_METAL
+ @testset "@parallel_indices (Float64)" begin
+ @require !@is_initialized()
+ @init_parallel_kernel($package, Float64)
+ @require @is_initialized()
+ expansion = @gorgeousstring(@parallel_indices (ix) f!(A) = (A[ix] = A[ix] + 1.0; return))
+ @test occursin("A[ix] = A[ix] + 1.0\n", expansion)
+ @reset_parallel_kernel()
+ end;
+ end
@testset "@parallel_indices (Float32)" begin
@require !@is_initialized()
@init_parallel_kernel($package, Float32)
@require @is_initialized()
- expansion = @gorgeousstring(@parallel_indices (ix) f!(A) = (A[ix] = A[ix] + 1.0; return))
+ expansion = @gorgeousstring(@parallel_indices (ix) f!(A) = (A[ix] = A[ix] + 1.0f0; return))
@test occursin("A[ix] = A[ix] + 1.0f0\n", expansion)
@reset_parallel_kernel()
end;
@@ -513,14 +537,16 @@ import Enzyme
@test occursin("A[ix] = A[ix] + Float16(1.0)\n", expansion)
@reset_parallel_kernel()
end;
- @testset "@parallel_indices (ComplexF64)" begin
- @require !@is_initialized()
- @init_parallel_kernel($package, ComplexF64)
- @require @is_initialized()
- expansion = @gorgeousstring(@parallel_indices (ix) f!(A) = (A[ix] = 2.0f0 - 1.0f0im - A[ix] + 1.0f0; return))
- @test occursin("A[ix] = ((2.0 - 1.0im) - A[ix]) + 1.0\n", expansion)
- @reset_parallel_kernel()
- end;
+ if $package != $PKG_METAL
+ @testset "@parallel_indices (ComplexF64)" begin
+ @require !@is_initialized()
+ @init_parallel_kernel($package, ComplexF64)
+ @require @is_initialized()
+ expansion = @gorgeousstring(@parallel_indices (ix) f!(A) = (A[ix] = 2.0f0 - 1.0f0im - A[ix] + 1.0f0; return))
+ @test occursin("A[ix] = ((2.0 - 1.0im) - A[ix]) + 1.0\n", expansion)
+ @reset_parallel_kernel()
+ end;
+ end
@testset "@parallel_indices (ComplexF32)" begin
@require !@is_initialized()
@init_parallel_kernel($package, ComplexF32)
@@ -541,7 +567,7 @@ import Enzyme
@testset "3. global defaults" begin
@testset "inbounds=true" begin
@require !@is_initialized()
- @init_parallel_kernel($package, Float64, inbounds=true)
+ @init_parallel_kernel($package, $precision, inbounds=true)
@require @is_initialized
expansion = @prettystring(1, @parallel_indices (ix) inbounds=true f(A) = (2*A; return))
@test occursin("Base.@inbounds begin", expansion)
@@ -602,7 +628,7 @@ import Enzyme
end;
@testset "5. Exceptions" begin
@require !@is_initialized()
- @init_parallel_kernel($package, Float64)
+ @init_parallel_kernel($package, $precision)
@require @is_initialized
@testset "arguments @parallel" begin
@test_throws ArgumentError checkargs_parallel(); # Error: isempty(args)
@@ -637,4 +663,6 @@ import Enzyme
@reset_parallel_kernel()
end;
end;
-)) end == nothing || true;
+))
+
+end end == nothing || true;
diff --git a/test/ParallelKernel/test_reset_parallel_kernel.jl b/test/ParallelKernel/test_reset_parallel_kernel.jl
index b3e3ae0a..fe2cc01a 100644
--- a/test/ParallelKernel/test_reset_parallel_kernel.jl
+++ b/test/ParallelKernel/test_reset_parallel_kernel.jl
@@ -1,7 +1,7 @@
using Test
import ParallelStencil
using ParallelStencil.ParallelKernel
-import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, @get_package, @get_numbertype, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_NONE, NUMBERTYPE_NONE
+import ParallelStencil.ParallelKernel: @reset_parallel_kernel, @is_initialized, @get_package, @get_numbertype, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_POLYESTER, PKG_NONE, NUMBERTYPE_NONE
import ParallelStencil.ParallelKernel: @require, @symbols
TEST_PACKAGES = SUPPORTED_PACKAGES
@static if PKG_CUDA in TEST_PACKAGES
@@ -12,6 +12,17 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ @static if Sys.isapple()
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+ else
+ TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES)
+ end
+end
+@static if PKG_POLYESTER in TEST_PACKAGES
+ import Polyester
+end
Base.retry_load_extensions() # Potentially needed to load the extensions after the packages have been filtered.
@static for package in TEST_PACKAGES eval(:(
diff --git a/test/runtests.jl b/test/runtests.jl
index 85ba20e5..cb847afd 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,9 +2,10 @@
push!(LOAD_PATH, "../src")
import ParallelStencil # Precompile it.
-import ParallelStencil: SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU
+import ParallelStencil: SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL
@static if (PKG_CUDA in SUPPORTED_PACKAGES) import CUDA end
@static if (PKG_AMDGPU in SUPPORTED_PACKAGES) import AMDGPU end
+@static if (PKG_METAL in SUPPORTED_PACKAGES && Sys.isapple()) import Metal end
excludedfiles = [ "test_excluded.jl", "test_incremental_compilation.jl"]; # TODO: test_incremental_compilation has to be deactivated until Polyester support released
@@ -25,6 +26,10 @@ function runtests()
@warn "Test Skip: All AMDGPU tests will be skipped because AMDGPU is not functional (if this is unexpected type `import AMDGPU; AMDGPU.functional()` to debug your AMDGPU installation)."
end
+ if (PKG_METAL in SUPPORTED_PACKAGES && (!Sys.isapple() || !Metal.functional()))
+ @warn "Test Skip: All Metal tests will be skipped because Metal is not functional (if this is unexpected type `import Metal; Metal.functional()` to debug your Metal installation)."
+ end
+
for f in testfiles
println("")
if basename(f) ∈ excludedfiles
diff --git a/test/test_FiniteDifferences1D.jl b/test/test_FiniteDifferences1D.jl
index bd058592..01f7a120 100644
--- a/test/test_FiniteDifferences1D.jl
+++ b/test/test_FiniteDifferences1D.jl
@@ -1,6 +1,6 @@
using Test
using ParallelStencil
-import ParallelStencil: @reset_parallel_stencil, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU
+import ParallelStencil: @reset_parallel_stencil, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_POLYESTER
import ParallelStencil: @require
using ParallelStencil.FiniteDifferences1D
TEST_PACKAGES = SUPPORTED_PACKAGES
@@ -12,12 +12,28 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ @static if Sys.isapple()
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+ else
+ TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES)
+ end
+end
+@static if PKG_POLYESTER in TEST_PACKAGES
+ import Polyester
+end
Base.retry_load_extensions() # Potentially needed to load the extensions after the packages have been filtered.
-@static for package in TEST_PACKAGES eval(:(
- @testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin
+const TEST_PRECISIONS = [Float32, Float64]
+@static for package in TEST_PACKAGES
+for precision in TEST_PRECISIONS
+(package == PKG_METAL && precision == Float64) && continue # Metal does not support Float64
+
+eval(:(
+ @testset "$(basename(@__FILE__)) (package: $(nameof($package))) (precision: $(nameof($precision)))" begin
@require !@is_initialized()
- @init_parallel_stencil($package, Float64, 1)
+ @init_parallel_stencil($package, $precision, 1)
@require @is_initialized()
nx = 7
A = @rand(nx );
@@ -40,7 +56,7 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
end;
@testset "averages" begin
@parallel av!(R, Ax) = (@all(R) = @av(Ax); return)
- R.=0; @parallel av!(R, Ax); @test all(Array(R .== (Ax[1:end-1].+Ax[2:end]).*0.5))
+ R.=0; @parallel av!(R, Ax); @test all(Array(R .== (Ax[1:end-1].+Ax[2:end]).*$precision(0.5)))
end;
@testset "harmonic averages" begin
@parallel harm!(R, Ax) = (@all(R) = @harm(Ax); return)
@@ -71,4 +87,7 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
end;
@reset_parallel_stencil()
end;
-)) end == nothing || true;
+))
+
+end
+end == nothing || true;
diff --git a/test/test_FiniteDifferences2D.jl b/test/test_FiniteDifferences2D.jl
index e836f3a8..d70b92a2 100644
--- a/test/test_FiniteDifferences2D.jl
+++ b/test/test_FiniteDifferences2D.jl
@@ -1,6 +1,6 @@
using Test
using ParallelStencil
-import ParallelStencil: @reset_parallel_stencil, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU
+import ParallelStencil: @reset_parallel_stencil, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_POLYESTER
import ParallelStencil: @require
using ParallelStencil.FiniteDifferences2D
TEST_PACKAGES = SUPPORTED_PACKAGES
@@ -12,12 +12,28 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ @static if Sys.isapple()
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+ else
+ TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES)
+ end
+end
+@static if PKG_POLYESTER in TEST_PACKAGES
+ import Polyester
+end
Base.retry_load_extensions() # Potentially needed to load the extensions after the packages have been filtered.
-@static for package in TEST_PACKAGES eval(:(
- @testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin
+const TEST_PRECISIONS = [Float32, Float64]
+@static for package in TEST_PACKAGES
+for precision in TEST_PRECISIONS
+(package == PKG_METAL && precision == Float64) && continue # Metal does not support Float64
+
+eval(:(
+ @testset "$(basename(@__FILE__)) (package: $(nameof($package))) (precision: $(nameof($precision)))" begin
@require !@is_initialized()
- @init_parallel_stencil($package, Float64, 2)
+ @init_parallel_stencil($package, $precision, 2)
@require @is_initialized()
nx, ny = 7, 5
A = @rand(nx, ny );
@@ -66,11 +82,11 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
@parallel av_ya!(R, Ay) = (@all(R) = @av_ya(Ay); return)
@parallel av_xi!(R, Axyy) = (@all(R) = @av_xi(Axyy); return)
@parallel av_yi!(R, Axxy) = (@all(R) = @av_yi(Axxy); return)
- R.=0; @parallel av!(R, Axy); @test all(Array(R .== (Axy[1:end-1,1:end-1].+Axy[2:end,1:end-1].+Axy[1:end-1,2:end].+Axy[2:end,2:end])*0.25))
- R.=0; @parallel av_xa!(R, Ax); @test all(Array(R .== (Ax[2:end, :].+Ax[1:end-1, :]).*0.5))
- R.=0; @parallel av_ya!(R, Ay); @test all(Array(R .== (Ay[ :,2:end].+Ay[ :,1:end-1]).*0.5))
- R.=0; @parallel av_xi!(R, Axyy); @test all(Array(R .== (Axyy[2:end ,2:end-1].+Axyy[1:end-1,2:end-1]).*0.5))
- R.=0; @parallel av_yi!(R, Axxy); @test all(Array(R .== (Axxy[2:end-1,2:end ].+Axxy[2:end-1,1:end-1]).*0.5))
+ R.=0; @parallel av!(R, Axy); @test all(Array(R .== (Axy[1:end-1,1:end-1].+Axy[2:end,1:end-1].+Axy[1:end-1,2:end].+Axy[2:end,2:end]).*$precision(0.25)))
+ R.=0; @parallel av_xa!(R, Ax); @test all(Array(R .== (Ax[2:end, :].+Ax[1:end-1, :]).*$precision(0.5)))
+ R.=0; @parallel av_ya!(R, Ay); @test all(Array(R .== (Ay[ :,2:end].+Ay[ :,1:end-1]).*$precision(0.5)))
+ R.=0; @parallel av_xi!(R, Axyy); @test all(Array(R .== (Axyy[2:end ,2:end-1].+Axyy[1:end-1,2:end-1]).*$precision(0.5)))
+ R.=0; @parallel av_yi!(R, Axxy); @test all(Array(R .== (Axxy[2:end-1,2:end ].+Axxy[2:end-1,1:end-1]).*$precision(0.5)))
end;
@testset "harmonic averages" begin
@parallel harm!(R, Axy) = (@all(R) = @harm(Axy); return)
@@ -112,4 +128,6 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
end;
@reset_parallel_stencil()
end;
-)) end == nothing || true;
+))
+
+end end == nothing || true;
diff --git a/test/test_FiniteDifferences3D.jl b/test/test_FiniteDifferences3D.jl
index 056ffae0..11db69db 100644
--- a/test/test_FiniteDifferences3D.jl
+++ b/test/test_FiniteDifferences3D.jl
@@ -1,6 +1,6 @@
using Test
using ParallelStencil
-import ParallelStencil: @reset_parallel_stencil, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU
+import ParallelStencil: @reset_parallel_stencil, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_POLYESTER
import ParallelStencil: @require
using ParallelStencil.FiniteDifferences3D
TEST_PACKAGES = SUPPORTED_PACKAGES
@@ -12,12 +12,28 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ @static if Sys.isapple()
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+ else
+ TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES)
+ end
+end
+@static if PKG_POLYESTER in TEST_PACKAGES
+ import Polyester
+end
Base.retry_load_extensions() # Potentially needed to load the extensions after the packages have been filtered.
-@static for package in TEST_PACKAGES eval(:(
- @testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin
+const TEST_PRECISIONS = [Float32, Float64]
+@static for package in TEST_PACKAGES
+for precision in TEST_PRECISIONS
+(package == PKG_METAL && precision == Float64) && continue # Metal does not support Float64
+
+eval(:(
+ @testset "$(basename(@__FILE__)) (package: $(nameof($package))) (precision: $(nameof($precision)))" begin
@require !@is_initialized()
- @init_parallel_stencil($package, Float64, 3)
+ @init_parallel_stencil($package, $precision, 3)
@require @is_initialized()
nx, ny, nz = 7, 5, 6
A = @rand(nx , ny , nz );
@@ -96,19 +112,19 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
@parallel av_xyi!(R, Axyzz) = (@all(R) = @av_xyi(Axyzz); return)
@parallel av_xzi!(R, Axyyz) = (@all(R) = @av_xzi(Axyyz); return)
@parallel av_yzi!(R, Axxyz) = (@all(R) = @av_yzi(Axxyz); return)
- R.=0; @parallel av!(R, Axyz); @test all(Array(R .== (Axyz[1:end-1,1:end-1,1:end-1].+Axyz[2:end,1:end-1,1:end-1].+Axyz[2:end,2:end,1:end-1].+Axyz[2:end,2:end,2:end].+Axyz[1:end-1,2:end,2:end].+Axyz[1:end-1,1:end-1,2:end].+Axyz[2:end,1:end-1,2:end].+Axyz[1:end-1,2:end,1:end-1])*0.125))
- R.=0; @parallel av_xa!(R, Ax); @test all(Array(R .== (Ax[2:end, :, :].+Ax[1:end-1, :, :]).*0.5))
- R.=0; @parallel av_ya!(R, Ay); @test all(Array(R .== (Ay[ :,2:end, :].+Ay[ :,1:end-1, :]).*0.5))
- R.=0; @parallel av_za!(R, Az); @test all(Array(R .== (Az[ :, :,2:end].+Az[ :, :,1:end-1]).*0.5))
- R.=0; @parallel av_xi!(R, Axyyzz); @test all(Array(R .== (Axyyzz[2:end ,2:end-1,2:end-1].+Axyyzz[1:end-1,2:end-1,2:end-1]).*0.5))
- R.=0; @parallel av_yi!(R, Axxyzz); @test all(Array(R .== (Axxyzz[2:end-1,2:end ,2:end-1].+Axxyzz[2:end-1,1:end-1,2:end-1]).*0.5))
- R.=0; @parallel av_zi!(R, Axxyyz); @test all(Array(R .== (Axxyyz[2:end-1,2:end-1,2:end ].+Axxyyz[2:end-1,2:end-1,1:end-1]).*0.5))
- R.=0; @parallel av_xya!(R, Axy); @test all(Array(R .== (Axy[1:end-1,1:end-1,:].+Axy[2:end,1:end-1,:].+Axy[1:end-1,2:end,:].+Axy[2:end,2:end,:])*0.25))
- R.=0; @parallel av_xza!(R, Axz); @test all(Array(R .== (Axz[1:end-1,:,1:end-1].+Axz[2:end,:,1:end-1].+Axz[1:end-1,:,2:end].+Axz[2:end,:,2:end])*0.25))
- R.=0; @parallel av_yza!(R, Ayz); @test all(Array(R .== (Ayz[:,1:end-1,1:end-1].+Ayz[:,2:end,1:end-1].+Ayz[:,1:end-1,2:end].+Ayz[:,2:end,2:end])*0.25))
- R.=0; @parallel av_xyi!(R, Axyzz); @test all(Array(R .== (Axyzz[1:end-1,1:end-1,2:end-1].+Axyzz[2:end,1:end-1,2:end-1].+Axyzz[1:end-1,2:end,2:end-1].+Axyzz[2:end,2:end,2:end-1])*0.25))
- R.=0; @parallel av_xzi!(R, Axyyz); @test all(Array(R .== (Axyyz[1:end-1,2:end-1,1:end-1].+Axyyz[2:end,2:end-1,1:end-1].+Axyyz[1:end-1,2:end-1,2:end].+Axyyz[2:end,2:end-1,2:end])*0.25))
- R.=0; @parallel av_yzi!(R, Axxyz); @test all(Array(R .== (Axxyz[2:end-1,1:end-1,1:end-1].+Axxyz[2:end-1,2:end,1:end-1].+Axxyz[2:end-1,1:end-1,2:end].+Axxyz[2:end-1,2:end,2:end])*0.25))
+ R.=0; @parallel av!(R, Axyz); @test all(Array(R .== (Axyz[1:end-1,1:end-1,1:end-1].+Axyz[2:end,1:end-1,1:end-1].+Axyz[2:end,2:end,1:end-1].+Axyz[2:end,2:end,2:end].+Axyz[1:end-1,2:end,2:end].+Axyz[1:end-1,1:end-1,2:end].+Axyz[2:end,1:end-1,2:end].+Axyz[1:end-1,2:end,1:end-1]).*$precision(0.125)))
+ R.=0; @parallel av_xa!(R, Ax); @test all(Array(R .== (Ax[2:end, :, :].+Ax[1:end-1, :, :]).*$precision(0.5)))
+ R.=0; @parallel av_ya!(R, Ay); @test all(Array(R .== (Ay[ :,2:end, :].+Ay[ :,1:end-1, :]).*$precision(0.5)))
+ R.=0; @parallel av_za!(R, Az); @test all(Array(R .== (Az[ :, :,2:end].+Az[ :, :,1:end-1]).*$precision(0.5)))
+ R.=0; @parallel av_xi!(R, Axyyzz); @test all(Array(R .== (Axyyzz[2:end ,2:end-1,2:end-1].+Axyyzz[1:end-1,2:end-1,2:end-1]).*$precision(0.5)))
+ R.=0; @parallel av_yi!(R, Axxyzz); @test all(Array(R .== (Axxyzz[2:end-1,2:end ,2:end-1].+Axxyzz[2:end-1,1:end-1,2:end-1]).*$precision(0.5)))
+ R.=0; @parallel av_zi!(R, Axxyyz); @test all(Array(R .== (Axxyyz[2:end-1,2:end-1,2:end ].+Axxyyz[2:end-1,2:end-1,1:end-1]).*$precision(0.5)))
+ R.=0; @parallel av_xya!(R, Axy); @test all(Array(R .== (Axy[1:end-1,1:end-1,:].+Axy[2:end,1:end-1,:].+Axy[1:end-1,2:end,:].+Axy[2:end,2:end,:]).*$precision(0.25)))
+ R.=0; @parallel av_xza!(R, Axz); @test all(Array(R .== (Axz[1:end-1,:,1:end-1].+Axz[2:end,:,1:end-1].+Axz[1:end-1,:,2:end].+Axz[2:end,:,2:end]).*$precision(0.25)))
+ R.=0; @parallel av_yza!(R, Ayz); @test all(Array(R .== (Ayz[:,1:end-1,1:end-1].+Ayz[:,2:end,1:end-1].+Ayz[:,1:end-1,2:end].+Ayz[:,2:end,2:end]).*$precision(0.25)))
+ R.=0; @parallel av_xyi!(R, Axyzz); @test all(Array(R .== (Axyzz[1:end-1,1:end-1,2:end-1].+Axyzz[2:end,1:end-1,2:end-1].+Axyzz[1:end-1,2:end,2:end-1].+Axyzz[2:end,2:end,2:end-1]).*$precision(0.25)))
+ R.=0; @parallel av_xzi!(R, Axyyz); @test all(Array(R .== (Axyyz[1:end-1,2:end-1,1:end-1].+Axyyz[2:end,2:end-1,1:end-1].+Axyyz[1:end-1,2:end-1,2:end].+Axyyz[2:end,2:end-1,2:end]).*$precision(0.25)))
+ R.=0; @parallel av_yzi!(R, Axxyz); @test all(Array(R .== (Axxyz[2:end-1,1:end-1,1:end-1].+Axxyz[2:end-1,2:end,1:end-1].+Axxyz[2:end-1,1:end-1,2:end].+Axxyz[2:end-1,2:end,2:end]).*$precision(0.25)))
end;
@testset "harmonic averages" begin
@parallel harm!(R, Axyz) = (@all(R) = @harm(Axyz); return)
@@ -166,4 +182,6 @@ Base.retry_load_extensions() # Potentially needed to load the extensions after t
end;
@reset_parallel_stencil()
end;
-)) end == nothing || true;
+))
+
+end end == nothing || true;
diff --git a/test/test_extensions.jl b/test/test_extensions.jl
index b76d5962..c79b7ded 100644
--- a/test/test_extensions.jl
+++ b/test/test_extensions.jl
@@ -1,5 +1,5 @@
using Test
-import ParallelStencil: SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_POLYESTER
+import ParallelStencil: SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_POLYESTER
TEST_PACKAGES = SUPPORTED_PACKAGES
TEST_PACKAGES = filter!(x->x≠PKG_POLYESTER, TEST_PACKAGES) # NOTE: Polyester is not tested here, because the CPU case is sufficiently covered by the test of the Threads package.
@static if PKG_CUDA in TEST_PACKAGES
@@ -10,6 +10,17 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ @static if Sys.isapple()
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+ else
+ TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES)
+ end
+end
+@static if PKG_POLYESTER in TEST_PACKAGES
+ import Polyester
+end
exename = joinpath(Sys.BINDIR, Base.julia_exename())
const TEST_PROJECTS = ["Diffusion3D_minimal"] # ["Diffusion3D_minimal", "Diffusion3D", "Diffusion"]
diff --git a/test/test_incremental_compilation.jl b/test/test_incremental_compilation.jl
index 5982dac8..e7da4fab 100644
--- a/test/test_incremental_compilation.jl
+++ b/test/test_incremental_compilation.jl
@@ -1,5 +1,5 @@
using Test
-import ParallelStencil: SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_POLYESTER
+import ParallelStencil: SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_POLYESTER
TEST_PACKAGES = SUPPORTED_PACKAGES
@static if PKG_CUDA in TEST_PACKAGES
import CUDA
@@ -9,6 +9,14 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ @static if Sys.isapple()
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+ else
+ TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES)
+ end
+end
@static if PKG_POLYESTER in TEST_PACKAGES
import Polyester
end
diff --git a/test/test_init_parallel_stencil.jl b/test/test_init_parallel_stencil.jl
index c4d51706..6483cccd 100644
--- a/test/test_init_parallel_stencil.jl
+++ b/test/test_init_parallel_stencil.jl
@@ -1,6 +1,6 @@
using Test
using ParallelStencil
-import ParallelStencil: @reset_parallel_stencil, @is_initialized, @get_package, @get_numbertype, @get_ndims, @get_inbounds, @get_memopt, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_NONE, NUMBERTYPE_NONE, NDIMS_NONE
+import ParallelStencil: @reset_parallel_stencil, @is_initialized, @get_package, @get_numbertype, @get_ndims, @get_inbounds, @get_memopt, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_POLYESTER, PKG_NONE, NUMBERTYPE_NONE, NDIMS_NONE
import ParallelStencil: @require, @symbols
import ParallelStencil: extract_posargs_init, extract_kwargs_init, check_already_initialized, set_initialized, is_initialized, check_initialized, set_package, set_numbertype, set_ndims, set_inbounds, set_memopt
using ParallelStencil.Exceptions
@@ -13,6 +13,17 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ @static if Sys.isapple()
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+ else
+ TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES)
+ end
+end
+@static if PKG_POLYESTER in TEST_PACKAGES
+ import Polyester
+end
Base.retry_load_extensions() # Potentially needed to load the extensions after the packages have been filtered.
@static for package in TEST_PACKAGES eval(:(
diff --git a/test/test_parallel.jl b/test/test_parallel.jl
index 5809cc15..d696865e 100644
--- a/test/test_parallel.jl
+++ b/test/test_parallel.jl
@@ -1,6 +1,6 @@
using Test
using ParallelStencil
-import ParallelStencil: @reset_parallel_stencil, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_THREADS, PKG_POLYESTER, INDICES
+import ParallelStencil: @reset_parallel_stencil, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_THREADS, PKG_POLYESTER, INDICES
import ParallelStencil: @require, @prettystring, @gorgeousstring, @isgpu, @iscpu
import ParallelStencil: checkargs_parallel, validate_body, parallel
using ParallelStencil.Exceptions
@@ -15,15 +15,27 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+end
+@static if PKG_POLYESTER in TEST_PACKAGES
+ import Polyester
+end
Base.retry_load_extensions() # Potentially needed to load the extensions after the packages have been filtered.
import ParallelStencil.@gorgeousexpand
-@static for package in TEST_PACKAGES eval(:(
- @testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin
+const TEST_PRECISIONS = [Float32, Float64]
+@static for package in TEST_PACKAGES
+for precision in TEST_PRECISIONS
+(package == PKG_METAL && precision == Float64) ? continue : nothing # Metal does not support Float64
+
+eval(:(
+ @testset "$(basename(@__FILE__)) (package: $(nameof($package))) (precision: $(nameof($precision)))" begin
@testset "1. parallel macros" begin
@require !@is_initialized()
- @init_parallel_stencil($package, Float64, 3)
+ @init_parallel_stencil($package, $precision, 3)
@require @is_initialized()
@testset "@parallel " begin # NOTE: calls must go to ParallelStencil.ParallelKernel.parallel and must therefore give the same result as in ParallelKernel, except for memopt tests (tests copied 1-to-1 from there).
@static if $package == $PKG_CUDA
@@ -212,13 +224,13 @@ import ParallelStencil.@gorgeousexpand
end
@testset "@parallel (3D; on-the-fly)" begin
nx, ny, nz = 32, 8, 8
- lam=dt=_dx=_dy=_dz = 1.0
+ lam=dt=_dx=_dy=_dz = $precision(1)
T = @zeros(nx, ny, nz);
T2 = @zeros(nx, ny, nz);
T2_ref = @zeros(nx, ny, nz);
Ci = @ones(nx, ny, nz);
copy!(T, [ix + (iy-1)*size(T,1) + (iz-1)*size(T,1)*size(T,2) for ix=1:size(T,1), iy=1:size(T,2), iz=1:size(T,3)].^3);
- @parallel function diffusion3D_step!(T2, T, Ci, lam::Data.Number, dt::Float64, _dx, _dy, _dz)
+ @parallel function diffusion3D_step!(T2, T, Ci, lam::Data.Number, dt::$precision, _dx, _dy, _dz)
@all(qx) = -lam*@d_xi(T)*_dx # Fourier's law of heat conduction
@all(qy) = -lam*@d_yi(T)*_dy # ...
@all(qz) = -lam*@d_zi(T)*_dz # ...
@@ -234,7 +246,7 @@ import ParallelStencil.@gorgeousexpand
);
@test all(Array(T2) .== Array(T2_ref))
end
- @static if $package in [$PKG_CUDA, $PKG_AMDGPU]
+ @static if $package in [$PKG_CUDA, $PKG_AMDGPU] # TODO add support for Metal
@testset "@parallel memopt (nx, ny, nz = x .* threads)" begin # NOTE: the following does not work for some reason: (nx, ny, nz = ($nx, $ny, $nz))" for (nx, ny, nz) in ((32, 8, 9), (32, 8, 8), (31, 7, 9), (33, 9, 9), (33, 7, 8))
nx, ny, nz = 32, 8, 8
# threads = (8, 4, 1)
@@ -269,12 +281,12 @@ import ParallelStencil.@gorgeousexpand
copy!(A, [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)].^3);
@parallel_indices (ix,iy,iz) memopt=true loopsize=3 function d2_memopt!(A2, A)
if (iz>1 && iz (3D, memopt, stencilranges=(0:0, -1:1, 0:0); y-stencil)" begin
@@ -284,12 +296,12 @@ import ParallelStencil.@gorgeousexpand
copy!(A, [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)].^3);
@parallel_indices (ix,iy,iz) memopt=true function d2_memopt!(A2, A)
if (iy>1 && iy (3D, memopt, stencilranges=(1:1, 1:1, 0:2); z-stencil)" begin
@@ -302,7 +314,7 @@ import ParallelStencil.@gorgeousexpand
return
end
@parallel memopt=true d2_memopt!(A2, A);
- A2_ref[2:end-1,2:end-1,2:end-1] .= A[2:end-1,2:end-1,3:end] .- 2.0.*A[2:end-1,2:end-1,2:end-1] .+ A[2:end-1,2:end-1,1:end-2];
+ A2_ref[2:end-1,2:end-1,2:end-1] .= (A[2:end-1,2:end-1,3:end] .- A[2:end-1,2:end-1,2:end-1]) .- (A[2:end-1,2:end-1,2:end-1] .- A[2:end-1,2:end-1,1:end-2]);
@test all(Array(A2) .== Array(A2_ref))
end
@testset "@parallel (3D, memopt, stencilranges=(1:1, 0:2, 1:1); y-stencil)" begin
@@ -315,11 +327,11 @@ import ParallelStencil.@gorgeousexpand
return
end
@parallel memopt=true d2_memopt!(A2, A);
- A2_ref[2:end-1,2:end-1,2:end-1] .= A[2:end-1,3:end,2:end-1] .- 2.0.*A[2:end-1,2:end-1,2:end-1] .+ A[2:end-1,1:end-2,2:end-1];
+ A2_ref[2:end-1,2:end-1,2:end-1] .= (A[2:end-1,3:end,2:end-1] .- A[2:end-1,2:end-1,2:end-1]) .- (A[2:end-1,2:end-1,2:end-1] .- A[2:end-1,1:end-2,2:end-1]);
@test all(Array(A2) .== Array(A2_ref))
end
@testset "@parallel_indices (3D, memopt, stencilranges=-1:1)" begin
- lam=dt=_dx=_dy=_dz = 1.0
+ lam=dt=_dx=_dy=_dz = $precision(1)
T = @zeros(nx, ny, nz);
T2 = @zeros(nx, ny, nz);
T2_ref = @zeros(nx, ny, nz);
@@ -344,7 +356,7 @@ import ParallelStencil.@gorgeousexpand
@test all(Array(T2) .== Array(T2_ref))
end
@testset "@parallel (3D, memopt, stencilranges=0:2)" begin
- lam=dt=_dx=_dy=_dz = 1.0
+ lam=dt=_dx=_dy=_dz = 1
T = @zeros(nx, ny, nz);
T2 = @zeros(nx, ny, nz);
T2_ref = @zeros(nx, ny, nz);
@@ -369,22 +381,22 @@ import ParallelStencil.@gorgeousexpand
copy!(A, [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)].^3);
@parallel_indices (ix,iy,iz) memopt=true loopsize=3 function higher_order_memopt!(A2, A)
if (ix-4>1 && ix-11 && iy+2<=size(A2,2) && iz-2>=1 && iz+3<=size(A2,3))
- A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2.0*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
+ A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
end
return
end
@parallel memopt=true higher_order_memopt!(A2, A);
- A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2.0.*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
+ A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
@test all(Array(A2) .== Array(A2_ref))
end
@testset "@parallel (3D, memopt, stencilranges=0:2; on-the-fly)" begin
- lam=dt=_dx=_dy=_dz = 1.0
+ lam=dt=_dx=_dy=_dz = $precision(1)
T = @zeros(nx, ny, nz);
T2 = @zeros(nx, ny, nz);
T2_ref = @zeros(nx, ny, nz);
Ci = @ones(nx, ny, nz);
copy!(T, [ix + (iy-1)*size(T,1) + (iz-1)*size(T,1)*size(T,2) for ix=1:size(T,1), iy=1:size(T,2), iz=1:size(T,3)].^3);
- @parallel memopt=true loopsize=3 function diffusion3D_step!(T2, T, Ci, lam::Data.Number, dt::Float64, _dx, _dy, _dz)
+ @parallel memopt=true loopsize=3 function diffusion3D_step!(T2, T, Ci, lam::Data.Number, dt::$precision, _dx, _dy, _dz)
@all(qx) = -lam*@d_xi(T)*_dx # Fourier's law of heat conduction
@all(qy) = -lam*@d_yi(T)*_dy # ...
@all(qz) = -lam*@d_zi(T)*_dz # ...
@@ -422,12 +434,12 @@ import ParallelStencil.@gorgeousexpand
copy!(B, 2 .* [ix + (iy-1)*size(B,1) + (iz-1)*size(B,1)*size(B,2) for ix=1:size(B,1), iy=1:size(B,2), iz=1:size(B,3)].^3);
@parallel_indices (ix,iy,iz) memopt=true loopsize=3 function d2_memopt!(A2, A, B)
if (iz>1 && iz (3D, memopt; 2 arrays, y-stencil)" begin
@@ -439,12 +451,12 @@ import ParallelStencil.@gorgeousexpand
copy!(B, 2 .* [ix + (iy-1)*size(B,1) + (iz-1)*size(B,1)*size(B,2) for ix=1:size(B,1), iy=1:size(B,2), iz=1:size(B,3)].^3);
@parallel_indices (ix,iy,iz) memopt=true loopsize=3 function d2_memopt!(A2, A, B)
if (iy>1 && iy (3D, memopt; 2 arrays, x-stencil)" begin
@@ -456,16 +468,16 @@ import ParallelStencil.@gorgeousexpand
copy!(B, 2 .* [ix + (iy-1)*size(B,1) + (iz-1)*size(B,1)*size(B,2) for ix=1:size(B,1), iy=1:size(B,2), iz=1:size(B,3)].^3);
@parallel_indices (ix,iy,iz) memopt=true function d2_memopt!(A2, A, B)
if (ix>1 && ix (3D, memopt; 2 arrays, x-y-z- + z-stencil)" begin
- lam=dt=_dx=_dy=_dz = 1.0
+ lam=dt=_dx=_dy=_dz = $precision(1)
T = @zeros(nx, ny, nz);
T2 = @zeros(nx, ny, nz);
T2_ref = @zeros(nx, ny, nz);
@@ -485,7 +497,7 @@ import ParallelStencil.@gorgeousexpand
@test all(Array(T2) .== Array(T2_ref))
end
@testset "@parallel (3D, memopt; 2 arrays, x-y-z- + x-stencil)" begin
- lam=dt=_dx=_dy=_dz = 1.0
+ lam=dt=_dx=_dy=_dz = $precision(1)
T = @zeros(nx, ny, nz);
T2 = @zeros(nx, ny, nz);
T2_ref = @zeros(nx, ny, nz);
@@ -505,7 +517,7 @@ import ParallelStencil.@gorgeousexpand
@test all(Array(T2) .== Array(T2_ref))
end
@testset "@parallel (3D, memopt; 3 arrays, x-y-z- + y- + x-stencil)" begin
- lam=dt=_dx=_dy=_dz = 1.0
+ lam=dt=_dx=_dy=_dz = $precision(1)
T = @zeros(nx, ny, nz);
T2 = @zeros(nx, ny, nz);
T2_ref = @zeros(nx, ny, nz);
@@ -541,20 +553,20 @@ import ParallelStencil.@gorgeousexpand
copy!(C, 3 .* [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)].^3);
@parallel_indices (ix,iy,iz) memopt=true loopsize=3 function higher_order_memopt!(A2, B2, C2, A, B, C)
if (ix-4>1 && ix-11 && iy+2<=size(A2,2) && iz-2>=1 && iz+3<=size(A2,3))
- A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2.0*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
+ A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
end
if (ix-4>1 && ix-11 && iy+2<=size(B2,2) && iz-2>=1 && iz+3<=size(B2,3))
- B2[ix-1,iy+2,iz] = B[ix-1,iy+2,iz+3] - 2.0*B[ix-3,iy+2,iz] + B[ix-4,iy+2,iz-2]
+ B2[ix-1,iy+2,iz] = B[ix-1,iy+2,iz+3] - 2*B[ix-3,iy+2,iz] + B[ix-4,iy+2,iz-2]
end
if (ix-4>1 && ix-11 && iy+2<=size(C2,2) && iz-2>=1 && iz+3<=size(C2,3))
- C2[ix-1,iy+2,iz] = C[ix-1,iy+2,iz+3] - 2.0*C[ix-3,iy+2,iz] + C[ix-4,iy+2,iz-2]
+ C2[ix-1,iy+2,iz] = C[ix-1,iy+2,iz+3] - 2*C[ix-3,iy+2,iz] + C[ix-4,iy+2,iz-2]
end
return
end
@parallel memopt=true higher_order_memopt!(A2, B2, C2, A, B, C);
- A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2.0.*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
- B2_ref[5:end-1,3:end,3:end-3] .= B[5:end-1,3:end,6:end] .- 2.0.*B[3:end-3,3:end,3:end-3] .+ B[2:end-4,3:end,1:end-5];
- C2_ref[5:end-1,3:end,3:end-3] .= C[5:end-1,3:end,6:end] .- 2.0.*C[3:end-3,3:end,3:end-3] .+ C[2:end-4,3:end,1:end-5];
+ A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
+ B2_ref[5:end-1,3:end,3:end-3] .= B[5:end-1,3:end,6:end] .- 2*B[3:end-3,3:end,3:end-3] .+ B[2:end-4,3:end,1:end-5];
+ C2_ref[5:end-1,3:end,3:end-3] .= C[5:end-1,3:end,6:end] .- 2*C[3:end-3,3:end,3:end-3] .+ C[2:end-4,3:end,1:end-5];
@test all(Array(A2) .== Array(A2_ref))
@test all(Array(B2) .== Array(B2_ref))
@test all(Array(C2) .== Array(C2_ref))
@@ -574,20 +586,20 @@ import ParallelStencil.@gorgeousexpand
copy!(C, 3 .* [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)].^3);
@parallel_indices (ix,iy,iz) memopt=true loopsize=3 function higher_order_memopt!(A2, B2, C2, A, B, C)
if (ix-4>1 && ix-11 && iy+2<=size(A2,2) && iz-2>=1 && iz+3<=size(A2,3))
- A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2.0*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
+ A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
end
if (ix-4>1 && ix-11 && iy+2<=size(B2,2) && iz+1>=1 && iz+2<=size(B2,3))
- B2[ix-1,iy+2,iz+1] = B[ix-1,iy+2,iz+2] - 2.0*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
+ B2[ix-1,iy+2,iz+1] = B[ix-1,iy+2,iz+2] - 2*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
end
if (ix-4>1 && ix-11 && iy+2<=size(C2,2) && iz-1>=1 && iz<=size(C2,3))
- C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2.0*C[ix-3,iy+2,iz-1] + C[ix-4,iy+2,iz-1]
+ C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2*C[ix-3,iy+2,iz-1] + C[ix-4,iy+2,iz-1]
end
return
end
@parallel memopt=true higher_order_memopt!(A2, B2, C2, A, B, C);
- A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2.0.*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
- B2_ref[5:end-1,3:end,2:end-1] .= B[5:end-1,3:end,3:end] .- 2.0.*B[3:end-3,3:end,2:end-1] .+ B[2:end-4,3:end,2:end-1];
- C2_ref[5:end-1,3:end,1:end-1] .= C[5:end-1,3:end,2:end] .- 2.0.*C[3:end-3,3:end,1:end-1] .+ C[2:end-4,3:end,1:end-1];
+ A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
+ B2_ref[5:end-1,3:end,2:end-1] .= B[5:end-1,3:end,3:end] .- 2*B[3:end-3,3:end,2:end-1] .+ B[2:end-4,3:end,2:end-1];
+ C2_ref[5:end-1,3:end,1:end-1] .= C[5:end-1,3:end,2:end] .- 2*C[3:end-3,3:end,1:end-1] .+ C[2:end-4,3:end,1:end-1];
@test all(Array(A2) .== Array(A2_ref))
@test all(Array(B2) .== Array(B2_ref))
@test all(Array(C2) .== Array(C2_ref))
@@ -607,20 +619,20 @@ import ParallelStencil.@gorgeousexpand
copy!(C, 3 .* [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)].^3);
@parallel_indices (ix,iy,iz) memopt=true loopsize=3 function higher_order_memopt!(A2, B2, C2, A, B, C)
if (ix-4>1 && ix-11 && iy+2<=size(A2,2) && iz-2>=1 && iz+3<=size(A2,3))
- A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2.0*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
+ A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
end
if (ix-4>1 && ix+11 && iy+2<=size(B2,2) && iz+1>=1 && iz+2<=size(B2,3))
- B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2.0*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
+ B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
end
if (ix-1>1 && ix-11 && iy+2<=size(C2,2) && iz-1>=1 && iz<=size(C2,3))
- C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2.0*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
+ C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
end
return
end
@parallel memopt=true higher_order_memopt!(A2, B2, C2, A, B, C);
- A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2.0.*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
- B2_ref[7:end-1,3:end,2:end-1] .= B[7:end-1,3:end,3:end] .- 2.0.*B[3:end-5,3:end,2:end-1] .+ B[2:end-6,3:end,2:end-1];
- C2_ref[2:end-1,3:end,1:end-1] .= C[2:end-1,3:end,2:end] .- 2.0.*C[2:end-1,3:end,1:end-1] .+ C[2:end-1,3:end,1:end-1];
+ A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
+ B2_ref[7:end-1,3:end,2:end-1] .= B[7:end-1,3:end,3:end] .- 2*B[3:end-5,3:end,2:end-1] .+ B[2:end-6,3:end,2:end-1];
+ C2_ref[2:end-1,3:end,1:end-1] .= C[2:end-1,3:end,2:end] .- 2*C[2:end-1,3:end,1:end-1] .+ C[2:end-1,3:end,1:end-1];
@test all(Array(A2) .== Array(A2_ref))
@test all(Array(B2) .== Array(B2_ref))
@test all(Array(C2) .== Array(C2_ref))
@@ -640,13 +652,13 @@ import ParallelStencil.@gorgeousexpand
copy!(C, 3 .* [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)].^3);
kernel = @gorgeousstring @parallel_indices (ix,iy,iz) memopt=true optvars=(A, C) loopdim=3 loopsize=3 optranges=(A=(-4:-1, 2:2, -2:3), B=(-4:1, 2:2, 1:2), C=(-1:-1, 2:2, -1:0)) function higher_order_memopt!(A2, B2, C2, A, B, C)
if (ix-4>1 && ix-11 && iy+2<=size(A2,2) && iz-2>=1 && iz+3<=size(A2,3))
- A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2.0*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
+ A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
end
if (ix-4>1 && ix+11 && iy+2<=size(B2,2) && iz+1>=1 && iz+2<=size(B2,3))
- B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2.0*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
+ B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
end
if (ix-1>1 && ix-11 && iy+2<=size(C2,2) && iz-1>=1 && iz<=size(C2,3))
- C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2.0*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
+ C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
end
return
end
@@ -654,28 +666,30 @@ import ParallelStencil.@gorgeousexpand
@test occursin("loopoffset = ((CUDA.blockIdx()).z - 1) * 3", kernel)
elseif $package == $PKG_AMDGPU
@test occursin("loopoffset = ((AMDGPU.workgroupIdx()).z - 1) * 3", kernel)
+ elseif $package == $PKG_METAL
+ @test occursin("loopoffset = ((Metal.threadgroup_position_in_grid_3d()).z - 1) * 3", kernel)
end
@test occursin("for i = -4:3", kernel)
@test occursin("tz = i + loopoffset", kernel)
- @test occursin("A2[ix - 1, iy + 2, iz] = (A_ixm1_iyp2_izp3 - 2.0A_ixm3_iyp2_iz) + A_ixm4_iyp2_izm2", kernel)
- @test occursin("B2[ix + 1, iy + 2, iz + 1] = (B[ix + 1, iy + 2, iz + 2] - 2.0 * B[ix - 3, iy + 2, iz + 1]) + B[ix - 4, iy + 2, iz + 1]", kernel)
- @test occursin("C2[ix - 1, iy + 2, iz - 1] = (C_ixm1_iyp2_iz - 2.0C_ixm1_iyp2_izm1) + C_ixm1_iyp2_izm1", kernel)
+ @test occursin("A2[ix - 1, iy + 2, iz] = (A_ixm1_iyp2_izp3 - 2A_ixm3_iyp2_iz) + A_ixm4_iyp2_izm2", kernel)
+ @test occursin("B2[ix + 1, iy + 2, iz + 1] = (B[ix + 1, iy + 2, iz + 2] - 2 * B[ix - 3, iy + 2, iz + 1]) + B[ix - 4, iy + 2, iz + 1]", kernel)
+ @test occursin("C2[ix - 1, iy + 2, iz - 1] = (C_ixm1_iyp2_iz - 2C_ixm1_iyp2_izm1) + C_ixm1_iyp2_izm1", kernel)
@parallel_indices (ix,iy,iz) memopt=true optvars=(A, C) loopdim=3 loopsize=3 optranges=(A=(-4:-1, 2:2, -2:3), B=(-4:1, 2:2, 1:2), C=(-1:-1, 2:2, -1:0)) function higher_order_memopt!(A2, B2, C2, A, B, C)
if (ix-4>1 && ix-11 && iy+2<=size(A2,2) && iz-2>=1 && iz+3<=size(A2,3))
- A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2.0*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
+ A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
end
if (ix-4>1 && ix+11 && iy+2<=size(B2,2) && iz+1>=1 && iz+2<=size(B2,3))
- B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2.0*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
+ B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
end
if (ix-1>1 && ix-11 && iy+2<=size(C2,2) && iz-1>=1 && iz<=size(C2,3))
- C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2.0*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
+ C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
end
return
end
@parallel memopt=true higher_order_memopt!(A2, B2, C2, A, B, C);
- A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2.0.*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
- B2_ref[7:end-1,3:end,2:end-1] .= B[7:end-1,3:end,3:end] .- 2.0.*B[3:end-5,3:end,2:end-1] .+ B[2:end-6,3:end,2:end-1];
- C2_ref[2:end-1,3:end,1:end-1] .= C[2:end-1,3:end,2:end] .- 2.0.*C[2:end-1,3:end,1:end-1] .+ C[2:end-1,3:end,1:end-1];
+ A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
+ B2_ref[7:end-1,3:end,2:end-1] .= B[7:end-1,3:end,3:end] .- 2*B[3:end-5,3:end,2:end-1] .+ B[2:end-6,3:end,2:end-1];
+ C2_ref[2:end-1,3:end,1:end-1] .= C[2:end-1,3:end,2:end] .- 2*C[2:end-1,3:end,1:end-1] .+ C[2:end-1,3:end,1:end-1];
@test all(Array(A2) .== Array(A2_ref))
@test all(Array(B2) .== Array(B2_ref))
@test all(Array(C2) .== Array(C2_ref))
@@ -695,13 +709,13 @@ import ParallelStencil.@gorgeousexpand
copy!(C, 3 .* [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)].^3);
kernel = @gorgeousstring @parallel_indices (ix,iy,iz) memopt=true optvars=(A, C) loopdim=3 loopsize=3 optranges=(A=(-4:-1, 2:2, -2:3), B=(-4:1, 2:2, 1:2), C=(-1:-1, 2:2, -1:0)) function higher_order_memopt!(A2, B2, C2, A, B, C)
if (ix-4>1 && ix-11 && iy+2<=size(A2,2) && iz-2>=1 && iz+3<=size(A2,3))
- A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2.0*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
+ A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
end
if (ix-4>1 && ix+11 && iy+2<=size(B2,2) && iz+1>=1 && iz+2<=size(B2,3))
- B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2.0*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
+ B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
end
if (ix-1>1 && ix-11 && iy+2<=size(C2,2) && iz-1>=1 && iz<=size(C2,3))
- C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2.0*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
+ C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
end
return
end
@@ -709,28 +723,30 @@ import ParallelStencil.@gorgeousexpand
@test occursin("loopoffset = ((CUDA.blockIdx()).z - 1) * 3", kernel)
elseif $package == $PKG_AMDGPU
@test occursin("loopoffset = ((AMDGPU.workgroupIdx()).z - 1) * 3", kernel)
+ elseif $package == $PKG_METAL
+ @test occursin("loopoffset = ((Metal.threadgroup_position_in_grid_3d()).z - 1) * 3", kernel)
end
@test occursin("for i = -4:3", kernel)
@test occursin("tz = i + loopoffset", kernel)
- @test occursin("A2[ix - 1, iy + 2, iz] = (A_ixm1_iyp2_izp3 - 2.0A_ixm3_iyp2_iz) + A_ixm4_iyp2_izm2", kernel)
- @test occursin("B2[ix + 1, iy + 2, iz + 1] = (B[ix + 1, iy + 2, iz + 2] - 2.0 * B[ix - 3, iy + 2, iz + 1]) + B[ix - 4, iy + 2, iz + 1]", kernel)
- @test occursin("C2[ix - 1, iy + 2, iz - 1] = (C_ixm1_iyp2_iz - 2.0C_ixm1_iyp2_izm1) + C_ixm1_iyp2_izm1", kernel)
+ @test occursin("A2[ix - 1, iy + 2, iz] = (A_ixm1_iyp2_izp3 - 2A_ixm3_iyp2_iz) + A_ixm4_iyp2_izm2", kernel)
+ @test occursin("B2[ix + 1, iy + 2, iz + 1] = (B[ix + 1, iy + 2, iz + 2] - 2 * B[ix - 3, iy + 2, iz + 1]) + B[ix - 4, iy + 2, iz + 1]", kernel)
+ @test occursin("C2[ix - 1, iy + 2, iz - 1] = (C_ixm1_iyp2_iz - 2C_ixm1_iyp2_izm1) + C_ixm1_iyp2_izm1", kernel)
@parallel_indices (ix,iy,iz) memopt=true optvars=(A, C) loopdim=3 loopsize=3 optranges=(A=(-4:-1, 2:2, -2:3), B=(-4:1, 2:2, 1:2), C=(-1:-1, 2:2, -1:0)) function higher_order_memopt!(A2, B2, C2, A, B, C)
if (ix-4>1 && ix-11 && iy+2<=size(A2,2) && iz-2>=1 && iz+3<=size(A2,3))
- A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2.0*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
+ A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
end
if (ix-4>1 && ix+11 && iy+2<=size(B2,2) && iz+1>=1 && iz+2<=size(B2,3))
- B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2.0*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
+ B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
end
if (ix-1>1 && ix-11 && iy+2<=size(C2,2) && iz-1>=1 && iz<=size(C2,3))
- C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2.0*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
+ C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
end
return
end
@parallel memopt=true higher_order_memopt!(A2, B2, C2, A, B, C);
- A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2.0.*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
- B2_ref[7:end-1,3:end,2:end-1] .= B[7:end-1,3:end,3:end] .- 2.0.*B[3:end-5,3:end,2:end-1] .+ B[2:end-6,3:end,2:end-1];
- C2_ref[2:end-1,3:end,1:end-1] .= C[2:end-1,3:end,2:end] .- 2.0.*C[2:end-1,3:end,1:end-1] .+ C[2:end-1,3:end,1:end-1];
+ A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
+ B2_ref[7:end-1,3:end,2:end-1] .= B[7:end-1,3:end,3:end] .- 2*B[3:end-5,3:end,2:end-1] .+ B[2:end-6,3:end,2:end-1];
+ C2_ref[2:end-1,3:end,1:end-1] .= C[2:end-1,3:end,2:end] .- 2*C[2:end-1,3:end,1:end-1] .+ C[2:end-1,3:end,1:end-1];
@test all(Array(A2) .== Array(A2_ref))
@test all(Array(B2) .== Array(B2_ref))
@test all(Array(C2) .== Array(C2_ref))
@@ -750,35 +766,35 @@ import ParallelStencil.@gorgeousexpand
copy!(C, 3 .* [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)].^3);
kernel = @gorgeousstring @parallel_indices (ix,iy,iz) memopt=true optvars=(A, B) loopdim=3 loopsize=3 optranges=(A=(-1:-1, 2:2, -2:3), B=(-4:-3, 2:2, 1:1)) function higher_order_memopt!(A2, B2, C2, A, B, C)
if (ix-4>1 && ix-11 && iy+2<=size(A2,2) && iz-2>=1 && iz+3<=size(A2,3))
- A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2.0*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
+ A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
end
if (ix-4>1 && ix+11 && iy+2<=size(B2,2) && iz+1>=1 && iz+2<=size(B2,3))
- B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2.0*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
+ B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
end
if (ix-1>1 && ix-11 && iy+2<=size(C2,2) && iz-1>=1 && iz<=size(C2,3))
- C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2.0*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
+ C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
end
return
end
- @test occursin("A2[ix - 1, iy + 2, iz] = (A_ixm1_iyp2_izp3 - 2.0 * A[ix - 3, iy + 2, iz]) + A[ix - 4, iy + 2, iz - 2]", kernel)
- @test occursin("B2[ix + 1, iy + 2, iz + 1] = (B[ix + 1, iy + 2, iz + 2] - 2.0B_ixm3_iyp2_izp1) + B_ixm4_iyp2_izp1", kernel) # NOTE: when z is restricted to 1:1 then x cannot include +1, as else the x-y range does not include any z (result: IncoherentArgumentError: incoherent argument in memopt: optranges in z dimension do not include any array access.).
- @test occursin("C2[ix - 1, iy + 2, iz - 1] = (C[ix - 1, iy + 2, iz] - 2.0 * C[ix - 1, iy + 2, iz - 1]) + C[ix - 1, iy + 2, iz - 1]", kernel)
+ @test occursin("A2[ix - 1, iy + 2, iz] = (A_ixm1_iyp2_izp3 - 2 * A[ix - 3, iy + 2, iz]) + A[ix - 4, iy + 2, iz - 2]", kernel)
+ @test occursin("B2[ix + 1, iy + 2, iz + 1] = (B[ix + 1, iy + 2, iz + 2] - 2B_ixm3_iyp2_izp1) + B_ixm4_iyp2_izp1", kernel) # NOTE: when z is restricted to 1:1 then x cannot include +1, as else the x-y range does not include any z (result: IncoherentArgumentError: incoherent argument in memopt: optranges in z dimension do not include any array access.).
+ @test occursin("C2[ix - 1, iy + 2, iz - 1] = (C[ix - 1, iy + 2, iz] - 2 * C[ix - 1, iy + 2, iz - 1]) + C[ix - 1, iy + 2, iz - 1]", kernel)
@parallel_indices (ix,iy,iz) memopt=true optvars=(A, B) loopdim=3 loopsize=3 optranges=(A=(-1:-1, 2:2, -2:3), B=(-4:-3, 2:2, 1:1)) function higher_order_memopt!(A2, B2, C2, A, B, C)
if (ix-4>1 && ix-11 && iy+2<=size(A2,2) && iz-2>=1 && iz+3<=size(A2,3))
- A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2.0*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
+ A2[ix-1,iy+2,iz] = A[ix-1,iy+2,iz+3] - 2*A[ix-3,iy+2,iz] + A[ix-4,iy+2,iz-2]
end
if (ix-4>1 && ix+11 && iy+2<=size(B2,2) && iz+1>=1 && iz+2<=size(B2,3))
- B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2.0*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
+ B2[ix+1,iy+2,iz+1] = B[ix+1,iy+2,iz+2] - 2*B[ix-3,iy+2,iz+1] + B[ix-4,iy+2,iz+1]
end
if (ix-1>1 && ix-11 && iy+2<=size(C2,2) && iz-1>=1 && iz<=size(C2,3))
- C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2.0*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
+ C2[ix-1,iy+2,iz-1] = C[ix-1,iy+2,iz] - 2*C[ix-1,iy+2,iz-1] + C[ix-1,iy+2,iz-1]
end
return
end
@parallel memopt=true higher_order_memopt!(A2, B2, C2, A, B, C);
- A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2.0.*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
- B2_ref[7:end-1,3:end,2:end-1] .= B[7:end-1,3:end,3:end] .- 2.0.*B[3:end-5,3:end,2:end-1] .+ B[2:end-6,3:end,2:end-1];
- C2_ref[2:end-1,3:end,1:end-1] .= C[2:end-1,3:end,2:end] .- 2.0.*C[2:end-1,3:end,1:end-1] .+ C[2:end-1,3:end,1:end-1];
+ A2_ref[5:end-1,3:end,3:end-3] .= A[5:end-1,3:end,6:end] .- 2*A[3:end-3,3:end,3:end-3] .+ A[2:end-4,3:end,1:end-5];
+ B2_ref[7:end-1,3:end,2:end-1] .= B[7:end-1,3:end,3:end] .- 2*B[3:end-5,3:end,2:end-1] .+ B[2:end-6,3:end,2:end-1];
+ C2_ref[2:end-1,3:end,1:end-1] .= C[2:end-1,3:end,2:end] .- 2*C[2:end-1,3:end,1:end-1] .+ C[2:end-1,3:end,1:end-1];
@test all(Array(A2) .== Array(A2_ref))
@test all(Array(B2) .== Array(B2_ref))
@test all(Array(C2) .== Array(C2_ref))
@@ -812,7 +828,7 @@ import ParallelStencil.@gorgeousexpand
@test all(Array(A2) .== Array(A))
end
@testset "@parallel (3D, memopt, stencilranges=0:2)" begin
- lam=dt=_dx=_dy=_dz = 1.0
+ lam=dt=_dx=_dy=_dz = $precision(1)
T = @zeros(nx, ny, nz);
T2 = @zeros(nx, ny, nz);
T2_ref = @zeros(nx, ny, nz);
@@ -831,7 +847,7 @@ import ParallelStencil.@gorgeousexpand
@test all(Array(T2) .== Array(T2_ref))
end
@testset "@parallel (3D, memopt; 3 arrays, x-y-z- + y- + x-stencil)" begin
- lam=dt=_dx=_dy=_dz = 1.0
+ lam=dt=_dx=_dy=_dz = $precision(1)
T = @zeros(nx, ny, nz);
T2 = @zeros(nx, ny, nz);
T2_ref = @zeros(nx, ny, nz);
@@ -864,12 +880,12 @@ import ParallelStencil.@gorgeousexpand
end;
@testset "2. parallel macros (2D)" begin
@require !@is_initialized()
- @init_parallel_stencil($package, Float64, 3)
+ @init_parallel_stencil($package, $precision, 2)
@require @is_initialized()
- @static if $package in [$PKG_CUDA, $PKG_AMDGPU]
+ @static if $package in [$PKG_CUDA, $PKG_AMDGPU] # TODO add support for Metal
nx, ny, nz = 32, 8, 1
@testset "@parallel_indices (2D, memopt, stencilranges=(-1:1,-1:1,0:0))" begin
- lam=dt=_dx=_dy = 1.0
+ lam=dt=_dx=_dy = $precision(1)
T = @zeros(nx, ny, nz);
T2 = @zeros(nx, ny, nz);
T2_ref = @zeros(nx, ny, nz);
@@ -897,7 +913,7 @@ import ParallelStencil.@gorgeousexpand
@testset "3. global defaults" begin
@testset "inbounds=true" begin
@require !@is_initialized()
- @init_parallel_stencil($package, Float64, 1, inbounds=true)
+ @init_parallel_stencil($package, $precision, 1, inbounds=true)
@require @is_initialized
expansion = @prettystring(1, @parallel_indices (ix) inbounds=true f(A) = (2*A; return))
@test occursin("Base.@inbounds begin", expansion)
@@ -909,40 +925,43 @@ import ParallelStencil.@gorgeousexpand
end;
@testset "@parallel_indices (I...) (1D)" begin
@require !@is_initialized()
- @init_parallel_stencil($package, Float64, 1)
+ @init_parallel_stencil($package, $precision, 1)
@require @is_initialized
A = @zeros(4*5*6)
- @parallel_indices (I...) function write_indices!(A)
- A[I...] = sum((I .- (1,)) .* (1.0));
+ one = $precision(1)
+ @parallel_indices (I...) function write_indices!(A, one)
+ A[I...] = sum((I .- (1,)) .* (one));
return
end
- @parallel write_indices!(A);
+ @parallel write_indices!(A, one);
@test all(Array(A) .== [(ix-1) for ix=1:size(A,1)])
@reset_parallel_stencil()
end;
@testset "@parallel_indices (I...) (2D)" begin
@require !@is_initialized()
- @init_parallel_stencil($package, Float64, 2)
+ @init_parallel_stencil($package, $precision, 2)
@require @is_initialized
A = @zeros(4, 5*6)
- @parallel_indices (I...) function write_indices!(A)
- A[I...] = sum((I .- (1,)) .* (1.0, size(A,1)));
+ one = $precision(1)
+ @parallel_indices (I...) function write_indices!(A, one)
+ A[I...] = sum((I .- (1,)) .* (one, size(A,1)));
return
end
- @parallel write_indices!(A);
+ @parallel write_indices!(A, one);
@test all(Array(A) .== [(ix-1) + (iy-1)*size(A,1) for ix=1:size(A,1), iy=1:size(A,2)])
@reset_parallel_stencil()
end;
@testset "@parallel_indices (I...) (3D)" begin
@require !@is_initialized()
- @init_parallel_stencil($package, Float64, 3)
+ @init_parallel_stencil($package, $precision, 3)
@require @is_initialized
A = @zeros(4, 5, 6)
- @parallel_indices (I...) function write_indices!(A)
- A[I...] = sum((I .- (1,)) .* (1.0, size(A,1), size(A,1)*size(A,2)));
+ one = $precision(1)
+ @parallel_indices (I...) function write_indices!(A, one)
+ A[I...] = sum((I .- (1,)) .* (one, size(A,1), size(A,1)*size(A,2)));
return
end
- @parallel write_indices!(A);
+ @parallel write_indices!(A, one);
@test all(Array(A) .== [(ix-1) + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)])
@reset_parallel_stencil()
end;
@@ -1042,7 +1061,7 @@ import ParallelStencil.@gorgeousexpand
@reset_parallel_stencil()
end;
@testset "5. Exceptions" begin
- @init_parallel_stencil($package, Float64, 3)
+ @init_parallel_stencil($package, $precision, 3)
@require @is_initialized
@testset "arguments @parallel" begin
@test_throws ArgumentError checkargs_parallel(); # Error: isempty(args)
@@ -1059,4 +1078,6 @@ import ParallelStencil.@gorgeousexpand
@reset_parallel_stencil()
end;
end;
-)) end == nothing || true;
+))
+
+end end == nothing || true;
diff --git a/test/test_projects/Diffusion3D_minimal/test/localtest_diffusion_Metal.jl b/test/test_projects/Diffusion3D_minimal/test/localtest_diffusion_Metal.jl
new file mode 100644
index 00000000..2f2df9e0
--- /dev/null
+++ b/test/test_projects/Diffusion3D_minimal/test/localtest_diffusion_Metal.jl
@@ -0,0 +1,8 @@
+push!(LOAD_PATH, "@stdlib") # NOTE: this is needed to enable this test to run from the Pkg manager
+push!(LOAD_PATH, joinpath(@__DIR__, ".."))
+using Test
+using Pkg
+Pkg.activate(joinpath(@__DIR__, ".."))
+Pkg.instantiate()
+import Metal
+using Diffusion3D_minimal
diff --git a/test/test_reset_parallel_stencil.jl b/test/test_reset_parallel_stencil.jl
index 481e6b52..08b66da5 100644
--- a/test/test_reset_parallel_stencil.jl
+++ b/test/test_reset_parallel_stencil.jl
@@ -1,6 +1,6 @@
using Test
using ParallelStencil
-import ParallelStencil: @reset_parallel_stencil, @is_initialized, @get_package, @get_numbertype, @get_ndims, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_NONE, NUMBERTYPE_NONE, NDIMS_NONE
+import ParallelStencil: @reset_parallel_stencil, @is_initialized, @get_package, @get_numbertype, @get_ndims, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_POLYESTER, PKG_NONE, NUMBERTYPE_NONE, NDIMS_NONE
import ParallelStencil: @require, @symbols
TEST_PACKAGES = SUPPORTED_PACKAGES
@static if PKG_CUDA in TEST_PACKAGES
@@ -11,6 +11,17 @@ end
import AMDGPU
if !AMDGPU.functional() TEST_PACKAGES = filter!(x->x≠PKG_AMDGPU, TEST_PACKAGES) end
end
+@static if PKG_METAL in TEST_PACKAGES
+ @static if Sys.isapple()
+ import Metal
+ if !Metal.functional() TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES) end
+ else
+ TEST_PACKAGES = filter!(x->x≠PKG_METAL, TEST_PACKAGES)
+ end
+end
+@static if PKG_POLYESTER in TEST_PACKAGES
+ import Polyester
+end
Base.retry_load_extensions() # Potentially needed to load the extensions after the packages have been filtered.
@static for package in TEST_PACKAGES eval(:(